Unverified Commit b73427c9 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into fix_for_multiconfig_generators

parents 55e635e5 4c059fa3
...@@ -29,14 +29,14 @@ ...@@ -29,14 +29,14 @@
#include <migraphx/op/argmax.hpp> #include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp> #include <migraphx/op/argmin.hpp>
template <class T, int Axis, bool LastIndex, int NonStdShape> template <class T, migraphx::shape::type_t DType, int Axis, bool LastIndex, int NonStdShape>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis, LastIndex, NonStdShape>> struct test_arg_ops : verify_program<test_arg_ops<T, DType, Axis, LastIndex, NonStdShape>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 1, 4, 1025}}; migraphx::shape s{DType, {2, 1, 4, 1025}};
auto param = mm->add_parameter("data", s); auto param = mm->add_parameter("data", s);
switch(NonStdShape) switch(NonStdShape)
{ {
...@@ -59,106 +59,211 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, LastIndex, NonStdShap ...@@ -59,106 +59,211 @@ struct test_arg_ops : verify_program<test_arg_ops<T, Axis, LastIndex, NonStdShap
} }
}; };
// transpose argmax tests // transpose argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, true, 0>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 0, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, 0, false, 0>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 0, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1, true, 0>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 1, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1, false, 0>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 1, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, 2, true, 0>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 2, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, 2, false, 0>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 2, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 0>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 3, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 0>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 3, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 0>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -1, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 0>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -1, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 0>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -2, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 0>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -2, false, 0>;
// transpose argmin tests // transpose argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, true, 0>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 0, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, 0, false, 0>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 0, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1, true, 0>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 1, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1, false, 0>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 1, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, 2, true, 0>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 2, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, 2, false, 0>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 2, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 0>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 3, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 0>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 3, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 0>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -3, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 0>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -3, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 0>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -4, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 0>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -4, false, 0>;
// broadcast argmax tests // broadcast argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, true, 1>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 0, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, 0, false, 1>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 0, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, 1, true, 1>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 1, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, 1, false, 1>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 1, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2, true, 1>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 2, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2, false, 1>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 2, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 1>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 3, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 1>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 3, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 1>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -1, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 1>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -1, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 1>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -2, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 1>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -2, false, 1>;
// broadcast argmin tests // broadcast argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, true, 1>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 0, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, 0, false, 1>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 0, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, 1, true, 1>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 1, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, 1, false, 1>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 1, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2, true, 1>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 2, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2, false, 1>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 2, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 1>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 3, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 1>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 3, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 1>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -3, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 1>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -3, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 1>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -4, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 1>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -4, false, 1>;
// slice argmax tests // slice argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, true, 2>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 0, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, 0, false, 2>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 0, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, 1, true, 2>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 1, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, 1, false, 2>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 1, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, 2, true, 2>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 2, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, 2, false, 2>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 2, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 2>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 3, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 2>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 3, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 2>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -1, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 2>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -1, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 2>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -2, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 2>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -2, false, 2>;
// slice argmin tests // slice argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, true, 2>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 0, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, 0, false, 2>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 0, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, 1, true, 2>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 1, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, 1, false, 2>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 1, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, 2, true, 2>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 2, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, 2, false, 2>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 2, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 2>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 3, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 2>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 3, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 2>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -3, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 2>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -3, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 2>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -4, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 2>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -4, false, 2>;
// default case, standard shape argmax tests // default case, standard shape argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, true, 3>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 0, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, 0, false, 3>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 0, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, 1, true, 3>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 1, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, 1, false, 3>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 1, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, 2, true, 3>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 2, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, 2, false, 3>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 2, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, true, 3>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 3, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, false, 3>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, 3, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, true, 3>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -1, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, false, 3>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -1, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, true, 3>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -2, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, false, 3>; template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::float_type, -2, false, 3>;
// default case, standard shape argmin tests // default case, standard shape argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, true, 3>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 0, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, 0, false, 3>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 0, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, 1, true, 3>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 1, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, 1, false, 3>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 1, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, 2, true, 3>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 2, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, 2, false, 3>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 2, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, true, 3>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 3, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, false, 3>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, 3, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3, true, 3>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -3, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, -3, false, 3>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -3, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, -4, true, 3>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -4, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, -4, false, 3>; template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::float_type, -4, false, 3>;
// transpose argmax tests
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 0, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 0, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 1, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 1, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 2, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 2, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 3, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 3, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -1, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -1, false, 0>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -2, true, 0>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -2, false, 0>;
// transpose argmin tests
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 0, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 0, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 1, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 1, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 2, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 2, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 3, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 3, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -3, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -3, false, 0>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -4, true, 0>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -4, false, 0>;
// broadcast argmax tests
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 0, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 0, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 1, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 1, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 2, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 2, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 3, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 3, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -1, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -1, false, 1>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -2, true, 1>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -2, false, 1>;
// broadcast argmin tests
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 0, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 0, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 1, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 1, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 2, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 2, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 3, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 3, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -3, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -3, false, 1>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -4, true, 1>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -4, false, 1>;
// slice argmax tests
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 0, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 0, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 1, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 1, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 2, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 2, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 3, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 3, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -1, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -1, false, 2>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -2, true, 2>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -2, false, 2>;
// slice argmin tests
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 0, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 0, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 1, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 1, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 2, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 2, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 3, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 3, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -3, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -3, false, 2>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -4, true, 2>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -4, false, 2>;
// default case, standard shape argmax tests
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 0, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 0, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 1, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 1, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 2, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 2, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 3, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, 3, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -1, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -1, false, 3>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -2, true, 3>;
template struct test_arg_ops<migraphx::op::argmax, migraphx::shape::fp8e4m3fnuz_type, -2, false, 3>;
// default case, standard shape argmin tests
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 0, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 0, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 1, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 1, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 2, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 2, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 3, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, 3, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -3, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -3, false, 3>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -4, true, 3>;
template struct test_arg_ops<migraphx::op::argmin, migraphx::shape::fp8e4m3fnuz_type, -4, false, 3>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_asin : verify_program<test_asin> template <migraphx::shape::type_t DType>
struct test_asin : verify_program<test_asin<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("asin"), x); mm->add_instruction(migraphx::make_op("asin"), x);
return p; return p;
} }
}; };
template struct test_asin<migraphx::shape::float_type>;
template struct test_asin<migraphx::shape::half_type>;
template struct test_asin<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_asinh : verify_program<test_asinh> template <migraphx::shape::type_t DType>
struct test_asinh : verify_program<test_asinh<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("asinh"), x); mm->add_instruction(migraphx::make_op("asinh"), x);
return p; return p;
} }
}; };
template struct test_asinh<migraphx::shape::float_type>;
template struct test_asinh<migraphx::shape::half_type>;
template struct test_asinh<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_atan : verify_program<test_atan> template <migraphx::shape::type_t DType>
struct test_atan : verify_program<test_atan<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("atan"), x); mm->add_instruction(migraphx::make_op("atan"), x);
return p; return p;
} }
}; };
template struct test_atan<migraphx::shape::float_type>;
template struct test_atan<migraphx::shape::half_type>;
template struct test_atan<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -23,20 +23,24 @@ ...@@ -23,20 +23,24 @@
*/ */
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_atanh : verify_program<test_atanh> template <typename CType>
struct test_atanh : verify_program<test_atanh<CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape::type_t dtype = migraphx::shape::get_type<CType>();
migraphx::shape s{dtype, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto min_val = mm->add_literal(-0.95f); auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {-0.95f}});
auto max_val = mm->add_literal(0.95f); auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {0.95f}});
min_val = min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val);
max_val = max_val =
...@@ -46,3 +50,7 @@ struct test_atanh : verify_program<test_atanh> ...@@ -46,3 +50,7 @@ struct test_atanh : verify_program<test_atanh>
return p; return p;
} }
}; };
template struct test_atanh<float>;
template struct test_atanh<migraphx::half>;
template struct test_atanh<migraphx::fp8::fp8e4m3fnuz>;
...@@ -27,16 +27,21 @@ ...@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_ceil : verify_program<test_ceil> template <migraphx::shape::type_t DType>
struct test_ceil : verify_program<test_ceil<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; migraphx::shape s{DType, {2, 3, 4, 6}};
auto param = mm->add_parameter("x", s); auto param = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("ceil"), param); mm->add_instruction(migraphx::make_op("ceil"), param);
return p; return p;
}; };
}; };
template struct test_ceil<migraphx::shape::float_type>;
template struct test_ceil<migraphx::shape::half_type>;
template struct test_ceil<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,16 +27,17 @@ ...@@ -27,16 +27,17 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_concat_axis_0 : verify_program<test_concat_axis_0> template <migraphx::shape::type_t DType>
struct test_concat_axis_0 : verify_program<test_concat_axis_0<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
int axis = 0; int axis = 0;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s0{DType, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; migraphx::shape s1{DType, {3, 2}};
migraphx::shape s2{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s2{DType, {1, 2}};
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto l1 = mm->add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto l2 = mm->add_parameter("z", s2); auto l2 = mm->add_parameter("z", s2);
...@@ -44,3 +45,8 @@ struct test_concat_axis_0 : verify_program<test_concat_axis_0> ...@@ -44,3 +45,8 @@ struct test_concat_axis_0 : verify_program<test_concat_axis_0>
return p; return p;
} }
}; };
template struct test_concat_axis_0<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_concat_axis_0<migraphx::shape::half_type>;
template struct test_concat_axis_0<migraphx::shape::float_type>;
template struct test_concat_axis_0<migraphx::shape::int32_type>;
...@@ -29,16 +29,20 @@ ...@@ -29,16 +29,20 @@
#include <cassert> #include <cassert>
struct test_contiguous : verify_program<test_contiguous> template <migraphx::shape::type_t DType>
struct test_contiguous : verify_program<test_contiguous<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}}; migraphx::shape s{DType, {4, 4, 4, 3}, {48, 4, 1, 16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("contiguous"), x); mm->add_instruction(migraphx::make_op("contiguous"), x);
assert(p.get_output_shapes().back().standard()); assert(p.get_output_shapes().back().standard());
return p; return p;
} }
}; };
template struct test_contiguous<migraphx::shape::float_type>;
template struct test_contiguous<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,17 +27,19 @@ ...@@ -27,17 +27,19 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_conv : verify_program<test_conv> template <migraphx::shape::type_t DType>
struct test_conv : verify_program<test_conv<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("convolution"), input, weights); mm->add_instruction(migraphx::make_op("convolution"), input, weights);
return p; return p;
} }
}; };
template struct test_conv<migraphx::shape::float_type>;
template struct test_conv<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,16 +27,15 @@ ...@@ -27,16 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_conv2 : verify_program<test_conv2> template <migraphx::shape::type_t DType>
struct test_conv2 : verify_program<test_conv2<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input = mm->add_parameter("x", migraphx::shape{DType, {1, 512, 28, 28}});
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 512, 28, 28}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {256, 512, 1, 1}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {256, 512, 1, 1}});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("convolution", migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}), {{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
...@@ -45,3 +44,5 @@ struct test_conv2 : verify_program<test_conv2> ...@@ -45,3 +44,5 @@ struct test_conv2 : verify_program<test_conv2>
return p; return p;
} }
}; };
template struct test_conv2<migraphx::shape::float_type>;
template struct test_conv2<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,18 +27,17 @@ ...@@ -27,18 +27,17 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_conv_add : verify_program<test_conv_add> template <migraphx::shape::type_t DType>
struct test_conv_add : verify_program<test_conv_add<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}}); auto x = mm->add_parameter("x", {DType, {1, 8, 4, 4}});
auto w = mm->add_literal( auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 1));
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 1)); auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 2));
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 2));
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v); auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v);
auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2); auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
...@@ -46,3 +45,6 @@ struct test_conv_add : verify_program<test_conv_add> ...@@ -46,3 +45,6 @@ struct test_conv_add : verify_program<test_conv_add>
return p; return p;
} }
}; };
template struct test_conv_add<migraphx::shape::float_type>;
template struct test_conv_add<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,18 +27,17 @@ ...@@ -27,18 +27,17 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides> template <migraphx::shape::type_t DType>
struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 2, 2}}); auto x = mm->add_parameter("x", {DType, {1, 8, 2, 2}});
auto w = mm->add_literal( auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 1, 1}}, 1));
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 1)); auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 1, 1}}, 2));
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 2));
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction( auto conv2 = mm->add_instruction(
migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v); migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v);
...@@ -47,3 +46,6 @@ struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_st ...@@ -47,3 +46,6 @@ struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_st
return p; return p;
} }
}; };
template struct test_conv_add_1x1_diff_strides<migraphx::shape::float_type>;
template struct test_conv_add_1x1_diff_strides<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -28,18 +28,17 @@ ...@@ -28,18 +28,17 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
struct test_conv_add_relu : verify_program<test_conv_add_relu> template <migraphx::shape::type_t DType>
struct test_conv_add_relu : verify_program<test_conv_add_relu<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = auto bias_literal =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}};
auto bias_literal = migraphx::literal{migraphx::shape{migraphx::shape::float_type, {4}},
{2.0f, 2.0f, 2.0f, 2.0f}};
auto bias = mm->add_literal(bias_literal); auto bias = mm->add_literal(bias_literal);
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto bcast_bias = mm->add_instruction( auto bcast_bias = mm->add_instruction(
...@@ -50,3 +49,6 @@ struct test_conv_add_relu : verify_program<test_conv_add_relu> ...@@ -50,3 +49,6 @@ struct test_conv_add_relu : verify_program<test_conv_add_relu>
return p; return p;
} }
}; };
template struct test_conv_add_relu<migraphx::shape::float_type>;
template struct test_conv_add_relu<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -29,26 +29,24 @@ ...@@ -29,26 +29,24 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu> template <migraphx::shape::type_t DType>
struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = auto l0 = migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}};
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto l0 = migraphx::literal{migraphx::shape{migraphx::shape::float_type, {4}},
{2.0f, 2.0f, 2.0f, 2.0f}};
auto bias = mm->add_literal(l0); auto bias = mm->add_literal(l0);
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto bcast_add = mm->add_instruction( auto bcast_add = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}), migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
bias); bias);
auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv, bcast_add); auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv, bcast_add);
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(migraphx::literal(DType, {0.0f}));
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(migraphx::literal(DType, {6.0f}));
min_val = mm->add_instruction( min_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", conv->get_shape().lens()}}), min_val); migraphx::make_op("multibroadcast", {{"out_lens", conv->get_shape().lens()}}), min_val);
max_val = mm->add_instruction( max_val = mm->add_instruction(
...@@ -57,3 +55,6 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu> ...@@ -57,3 +55,6 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
return p; return p;
} }
}; };
template struct test_conv_bias_clipped_relu<migraphx::shape::float_type>;
template struct test_conv_bias_clipped_relu<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -29,16 +29,17 @@ ...@@ -29,16 +29,17 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/common.hpp> #include <migraphx/common.hpp>
struct test_conv_bn : verify_program<test_conv_bn> template <migraphx::shape::type_t DType>
struct test_conv_bn : verify_program<test_conv_bn<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}}; migraphx::shape xs{DType, {1, 3, 224, 224}};
migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}}; migraphx::shape ws{DType, {64, 3, 7, 7}};
migraphx::shape vars{migraphx::shape::float_type, {64}}; migraphx::shape vars{DType, {64}};
auto x = mm->add_parameter("x", xs); auto x = mm->add_parameter("x", xs);
auto w = mm->add_parameter("w", ws); auto w = mm->add_parameter("w", ws);
// non-symmetrical tiling // non-symmetrical tiling
...@@ -53,8 +54,14 @@ struct test_conv_bn : verify_program<test_conv_bn> ...@@ -53,8 +54,14 @@ struct test_conv_bn : verify_program<test_conv_bn>
auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); auto rt = mm->add_literal(migraphx::literal{DType, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}});
auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}});
if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type)
{
// use 5e-2f for the fp8
eps = mm->add_literal(migraphx::literal{DType, {5e-2f}});
}
auto usq_scale = auto usq_scale =
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
...@@ -74,3 +81,6 @@ struct test_conv_bn : verify_program<test_conv_bn> ...@@ -74,3 +81,6 @@ struct test_conv_bn : verify_program<test_conv_bn>
return p; return p;
} }
}; };
template struct test_conv_bn<migraphx::shape::float_type>;
template struct test_conv_bn<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -29,22 +29,27 @@ ...@@ -29,22 +29,27 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/common.hpp> #include <migraphx/common.hpp>
struct test_conv_bn_add : verify_program<test_conv_bn_add> template <migraphx::shape::type_t DType>
struct test_conv_bn_add : verify_program<test_conv_bn_add<DType>>
{ {
static migraphx::instruction_ref add_bn(migraphx::module& m, migraphx::instruction_ref x) static migraphx::instruction_ref add_bn(migraphx::module& m, migraphx::instruction_ref x)
{ {
auto bn_lens = x->get_shape().lens(); auto bn_lens = x->get_shape().lens();
auto c_len = bn_lens.at(1); auto c_len = bn_lens.at(1);
migraphx::shape vars{migraphx::shape::float_type, {c_len}}; migraphx::shape vars{DType, {c_len}};
auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + c_len))); auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + c_len)));
auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + c_len))); auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + c_len)));
auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + c_len))); auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + c_len)));
auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + c_len))); auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + c_len)));
auto rt = m.add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); auto rt = m.add_literal(migraphx::literal{DType, {0.5}});
auto eps = m.add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); auto eps = m.add_literal(migraphx::literal{DType, {1e-5f}});
if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type)
{
// use 5e-2f for the fp8
eps = m.add_literal(migraphx::literal{DType, {5e-2f}});
}
auto usq_scale = auto usq_scale =
m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
auto usq_bias = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias); auto usq_bias = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias);
...@@ -66,12 +71,12 @@ struct test_conv_bn_add : verify_program<test_conv_bn_add> ...@@ -66,12 +71,12 @@ struct test_conv_bn_add : verify_program<test_conv_bn_add>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::size_t ichannels = 64; std::size_t ichannels = 64;
std::size_t ochannels = 256; std::size_t ochannels = 256;
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, 56}}); auto x = mm->add_parameter("x", {DType, {1, ichannels, 56, 56}});
auto w = mm->add_literal(migraphx::generate_literal( auto w =
{migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1)); mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 1));
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56, 56}}); auto y = mm->add_parameter("y", {DType, {1, ichannels, 56, 56}});
auto v = mm->add_literal(migraphx::generate_literal( auto v =
{migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2)); mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 2));
auto relu1 = mm->add_instruction(migraphx::make_op("relu"), x); auto relu1 = mm->add_instruction(migraphx::make_op("relu"), x);
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), relu1, w); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), relu1, w);
auto bn1 = add_bn(*mm, conv1); auto bn1 = add_bn(*mm, conv1);
...@@ -83,3 +88,6 @@ struct test_conv_bn_add : verify_program<test_conv_bn_add> ...@@ -83,3 +88,6 @@ struct test_conv_bn_add : verify_program<test_conv_bn_add>
return p; return p;
} }
}; };
template struct test_conv_bn_add<migraphx::shape::float_type>;
template struct test_conv_bn_add<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -30,16 +30,17 @@ ...@@ -30,16 +30,17 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/common.hpp> #include <migraphx/common.hpp>
struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling> template <migraphx::shape::type_t DType>
struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}}; migraphx::shape xs{DType, {1, 3, 224, 224}};
migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}}; migraphx::shape ws{DType, {64, 3, 7, 7}};
migraphx::shape vars{migraphx::shape::float_type, {64}}; migraphx::shape vars{DType, {64}};
auto x = mm->add_parameter("x", xs); auto x = mm->add_parameter("x", xs);
auto w = mm->add_parameter("w", ws); auto w = mm->add_parameter("w", ws);
auto conv = mm->add_instruction( auto conv = mm->add_instruction(
...@@ -52,9 +53,13 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling> ...@@ -52,9 +53,13 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); auto rt = mm->add_literal(migraphx::literal{DType, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}});
if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type)
{
// use 5e-2f for the fp8
eps = mm->add_literal(migraphx::literal{DType, {5e-2f}});
}
auto usq_scale = auto usq_scale =
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
auto usq_bias = auto usq_bias =
...@@ -82,3 +87,6 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling> ...@@ -82,3 +87,6 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
return p; return p;
} }
}; };
template struct test_conv_bn_relu_pooling<migraphx::shape::float_type>;
template struct test_conv_bn_relu_pooling<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -30,22 +30,27 @@ ...@@ -30,22 +30,27 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/common.hpp> #include <migraphx/common.hpp>
struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2> template <migraphx::shape::type_t DType>
struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2<DType>>
{ {
static migraphx::instruction_ref add_bn(migraphx::module& m, migraphx::instruction_ref x) static migraphx::instruction_ref add_bn(migraphx::module& m, migraphx::instruction_ref x)
{ {
auto bn_lens = x->get_shape().lens(); auto bn_lens = x->get_shape().lens();
auto c_len = bn_lens.at(1); auto c_len = bn_lens.at(1);
migraphx::shape vars{migraphx::shape::float_type, {c_len}}; migraphx::shape vars{DType, {c_len}};
auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + c_len))); auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + c_len)));
auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + c_len))); auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + c_len)));
auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + c_len))); auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + c_len)));
auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + c_len))); auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + c_len)));
auto rt = m.add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); auto rt = m.add_literal(migraphx::literal{DType, {0.5}});
auto eps = m.add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); auto eps = m.add_literal(migraphx::literal{DType, {1e-5f}});
if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type)
{
// use 5e-2f for the fp8
eps = m.add_literal(migraphx::literal{DType, {5e-2f}});
}
auto usq_scale = auto usq_scale =
m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
auto usq_bias = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias); auto usq_bias = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias);
...@@ -66,10 +71,10 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2> ...@@ -66,10 +71,10 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape xs1{migraphx::shape::float_type, {1, 512, 7, 7}}; migraphx::shape xs1{DType, {1, 512, 7, 7}};
migraphx::shape xs2{migraphx::shape::float_type, {1, 1024, 14, 14}}; migraphx::shape xs2{DType, {1, 1024, 14, 14}};
migraphx::shape ws1{migraphx::shape::float_type, {2048, 512, 1, 1}}; migraphx::shape ws1{DType, {2048, 512, 1, 1}};
migraphx::shape ws2{migraphx::shape::float_type, {2048, 1024, 1, 1}}; migraphx::shape ws2{DType, {2048, 1024, 1, 1}};
auto x1 = mm->add_parameter("x1", xs1); auto x1 = mm->add_parameter("x1", xs1);
auto w1 = mm->add_parameter("w1", ws1); auto w1 = mm->add_parameter("w1", ws1);
auto conv1 = mm->add_instruction( auto conv1 = mm->add_instruction(
...@@ -98,3 +103,6 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2> ...@@ -98,3 +103,6 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
return p; return p;
} }
}; };
template struct test_conv_bn_relu_pooling2<migraphx::shape::float_type>;
template struct test_conv_bn_relu_pooling2<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,16 +27,17 @@ ...@@ -27,16 +27,17 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_conv_group_add : verify_program<test_conv_group_add> template <migraphx::shape::type_t DType>
struct test_conv_group_add : verify_program<test_conv_group_add<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1, 68, 28, 28}}; migraphx::shape s{DType, {1, 68, 28, 28}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto w = mm->add_parameter("w", {migraphx::shape::float_type, {68, 17, 1, 1}}); auto w = mm->add_parameter("w", {DType, {68, 17, 1, 1}});
auto b = mm->add_parameter("b", {migraphx::shape::float_type, {68}}); auto b = mm->add_parameter("b", {DType, {68}});
auto conv = mm->add_instruction(migraphx::make_op("convolution", {{"group", 4}}), x, w); auto conv = mm->add_instruction(migraphx::make_op("convolution", {{"group", 4}}), x, w);
auto bb = mm->add_instruction( auto bb = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 68, 28, 28}}}), b); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 68, 28, 28}}}), b);
...@@ -44,3 +45,6 @@ struct test_conv_group_add : verify_program<test_conv_group_add> ...@@ -44,3 +45,6 @@ struct test_conv_group_add : verify_program<test_conv_group_add>
return p; return p;
} }
}; };
template struct test_conv_group_add<migraphx::shape::float_type>;
// grouped convolutions are not supported with MLIR therefore disable it
// template struct test_conv_group_add<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -28,16 +28,15 @@ ...@@ -28,16 +28,15 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
struct test_conv_pooling : verify_program<test_conv_pooling> template <migraphx::shape::type_t DType>
struct test_conv_pooling : verify_program<test_conv_pooling<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 32, 32}});
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 32, 32}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto pooling = mm->add_instruction( auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv); migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv);
...@@ -45,3 +44,6 @@ struct test_conv_pooling : verify_program<test_conv_pooling> ...@@ -45,3 +44,6 @@ struct test_conv_pooling : verify_program<test_conv_pooling>
return p; return p;
} }
}; };
template struct test_conv_pooling<migraphx::shape::float_type>;
template struct test_conv_pooling<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment