#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>

/// @brief Verify test for arg-reduction operators (argmax / argmin) over
///        differently laid-out inputs.
///
/// The program takes one float parameter "data" with lens {2, 1, 4, 1025}.
/// It optionally reshapes the input's layout with one extra operator, then
/// applies `T{Axis}` to it.
///
/// @tparam T           Operator type constructed as `T{Axis}`; instantiated
///                     below with migraphx::op::argmax and migraphx::op::argmin.
/// @tparam Axis        Reduction axis passed to the operator. Negative values
///                     are intended to index from the last dimension.
/// @tparam NonStdShape Selects the operator applied to the input first:
///                     0 -> transpose, permutation {0, 2, 3, 1} (lens {2, 4, 1025, 1})
///                     1 -> multibroadcast to lens {2, 3, 4, 1025}
///                     2 -> slice axis 2 over [1, 3)              (lens {2, 1, 2, 1025})
///                     any other value -> input is used unchanged.
template <class T, int Axis, int NonStdShape>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis, NonStdShape>>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape s{migraphx::shape::float_type, {2, 1, 4, 1025}};
        auto param = mm->add_parameter("data", s);
        // Put an operator in front of the reduction so that each variant
        // receives a differently laid-out input.
        switch(NonStdShape)
        {
        case 0:
            param = mm->add_instruction(
                migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), param);
            break;
        case 1:
            param = mm->add_instruction(
                migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4, 1025}}}), param);
            break;
        case 2:
            param = mm->add_instruction(
                migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {3}}}),
                param);
            break;
        default: break; // standard input layout, no extra operator
        }
        mm->add_instruction(T{Axis}, param);
        return p;
    }
};

// Each explicit instantiation below defines one test configuration
// <operator, axis, NonStdShape case>.
// NOTE(review): these argument lists were reconstructed from a copy whose
// <...> contents had been stripped. The NonStdShape value per group follows
// the group comments; the six axes per group ({0, 1, 2, 3, -1, -2}) should be
// confirmed against upstream.

// transpose argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1, 0>;
template struct test_arg_ops<migraphx::op::argmax, 2, 0>;
template struct test_arg_ops<migraphx::op::argmax, 3, 0>;
template struct test_arg_ops<migraphx::op::argmax, -1, 0>;
template struct test_arg_ops<migraphx::op::argmax, -2, 0>;
// transpose argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1, 0>;
template struct test_arg_ops<migraphx::op::argmin, 2, 0>;
template struct test_arg_ops<migraphx::op::argmin, 3, 0>;
template struct test_arg_ops<migraphx::op::argmin, -1, 0>;
template struct test_arg_ops<migraphx::op::argmin, -2, 0>;
// broadcast argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 1>;
template struct test_arg_ops<migraphx::op::argmax, 1, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2, 1>;
template struct test_arg_ops<migraphx::op::argmax, 3, 1>;
template struct test_arg_ops<migraphx::op::argmax, -1, 1>;
template struct test_arg_ops<migraphx::op::argmax, -2, 1>;
// broadcast argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 1>;
template struct test_arg_ops<migraphx::op::argmin, 1, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2, 1>;
template struct test_arg_ops<migraphx::op::argmin, 3, 1>;
template struct test_arg_ops<migraphx::op::argmin, -1, 1>;
template struct test_arg_ops<migraphx::op::argmin, -2, 1>;
// slice argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 2>;
template struct test_arg_ops<migraphx::op::argmax, 1, 2>;
template struct test_arg_ops<migraphx::op::argmax, 2, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3, 2>;
template struct test_arg_ops<migraphx::op::argmax, -1, 2>;
template struct test_arg_ops<migraphx::op::argmax, -2, 2>;
// slice argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 2>;
template struct test_arg_ops<migraphx::op::argmin, 1, 2>;
template struct test_arg_ops<migraphx::op::argmin, 2, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3, 2>;
template struct test_arg_ops<migraphx::op::argmin, -1, 2>;
template struct test_arg_ops<migraphx::op::argmin, -2, 2>;
// default case, standard shape argmax tests
template struct test_arg_ops<migraphx::op::argmax, 0, 3>;
template struct test_arg_ops<migraphx::op::argmax, 1, 3>;
template struct test_arg_ops<migraphx::op::argmax, 2, 3>;
template struct test_arg_ops<migraphx::op::argmax, 3, 3>;
template struct test_arg_ops<migraphx::op::argmax, -1, 3>;
template struct test_arg_ops<migraphx::op::argmax, -2, 3>;
// default case, standard shape argmin tests
template struct test_arg_ops<migraphx::op::argmin, 0, 3>;
template struct test_arg_ops<migraphx::op::argmin, 1, 3>;
template struct test_arg_ops<migraphx::op::argmin, 2, 3>;
template struct test_arg_ops<migraphx::op::argmin, 3, 3>;
template struct test_arg_ops<migraphx::op::argmin, -1, 3>;
template struct test_arg_ops<migraphx::op::argmin, -2, 3>;