Commit 91cc9c7c authored by Umang Yadav's avatar Umang Yadav
Browse files

Add math and reduce tests

parent bd0ae5fa
...@@ -41,3 +41,4 @@ struct test_pow : verify_program<test_pow> ...@@ -41,3 +41,4 @@ struct test_pow : verify_program<test_pow>
return p; return p;
} }
}; };
// TODO: add fp8 tests
...@@ -22,19 +22,21 @@ ...@@ -22,19 +22,21 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/shape.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
struct test_reduce_add : verify_program<test_reduce_add> template <migraphx::shape::type_t DType>
struct test_reduce_add : verify_program<test_reduce_add<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4, 1000, 2, 2}}; migraphx::shape s{DType, {4, 1000, 2, 2}};
migraphx::shape bs{migraphx::shape::half_type, {1, 32, 128}}; migraphx::shape bs{migraphx::shape::half_type, {1, 32, 128}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto reduce_mean = auto reduce_mean =
...@@ -46,3 +48,6 @@ struct test_reduce_add : verify_program<test_reduce_add> ...@@ -46,3 +48,6 @@ struct test_reduce_add : verify_program<test_reduce_add>
return p; return p;
}; };
}; };
template struct test_reduce_add<migraphx::shape::float_type>;
template struct test_reduce_add<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -28,14 +28,14 @@ ...@@ -28,14 +28,14 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc> template <migraphx::shape::type_t DType>
struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto s = migraphx::shape::from_permutation( auto s = migraphx::shape::from_permutation(DType, {4, 256, 2, 2}, {0, 2, 3, 1});
migraphx::shape::float_type, {4, 256, 2, 2}, {0, 2, 3, 1});
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto reduce = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), x); auto reduce = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), x);
auto abs = mm->add_instruction(migraphx::make_op("abs"), reduce); auto abs = mm->add_instruction(migraphx::make_op("abs"), reduce);
...@@ -44,3 +44,7 @@ struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc> ...@@ -44,3 +44,7 @@ struct test_reduce_mean_nhwc : verify_program<test_reduce_mean_nhwc>
return p; return p;
}; };
}; };
template struct test_reduce_mean_nhwc<migraphx::shape::float_type>;
template struct test_reduce_mean_nhwc<migraphx::shape::half_type>;
template struct test_reduce_mean_nhwc<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -51,6 +51,22 @@ template struct test_reduce_op_large<migraphx::op::reduce_min, 1, migraphx::shap ...@@ -51,6 +51,22 @@ template struct test_reduce_op_large<migraphx::op::reduce_min, 1, migraphx::shap
template struct test_reduce_op_large<migraphx::op::reduce_prod, 2, migraphx::shape::float_type>; template struct test_reduce_op_large<migraphx::op::reduce_prod, 2, migraphx::shape::float_type>;
template struct test_reduce_op_large<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>; template struct test_reduce_op_large<migraphx::op::reduce_sum, 1, migraphx::shape::float_type>;
template struct test_reduce_op_large<migraphx::op::reduce_max,
1,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_large<migraphx::op::reduce_mean,
1,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_large<migraphx::op::reduce_min,
1,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_large<migraphx::op::reduce_prod,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_large<migraphx::op::reduce_sum,
1,
migraphx::shape::fp8e4m3fnuz_type>;
struct test_reduce_mean_1 : verify_program<test_reduce_mean_1> struct test_reduce_mean_1 : verify_program<test_reduce_mean_1>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -56,3 +56,19 @@ template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::sha ...@@ -56,3 +56,19 @@ template struct test_reduce_op_small<migraphx::op::reduce_mean, 2, migraphx::sha
template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_max, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_min, 2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_prod, -2, migraphx::shape::half_type>; template struct test_reduce_op_small<migraphx::op::reduce_prod, -2, migraphx::shape::half_type>;
template struct test_reduce_op_small<migraphx::op::reduce_sum,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_mean,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_max,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_min,
2,
migraphx::shape::fp8e4m3fnuz_type>;
template struct test_reduce_op_small<migraphx::op::reduce_prod,
-2,
migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,16 @@ ...@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_roialign : verify_program<test_roialign> template <migraphx::shape::type_t DType>
struct test_roialign : verify_program<test_roialign<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape x_s{migraphx::shape::float_type, {5, 4, 10, 10}}; migraphx::shape x_s{DType, {5, 4, 10, 10}};
migraphx::shape roi_s{migraphx::shape::float_type, {5, 4}}; migraphx::shape roi_s{DType, {5, 4}};
migraphx::shape ind_s{migraphx::shape::int64_type, {5}}; migraphx::shape ind_s{migraphx::shape::int64_type, {5}};
std::vector<int64_t> ind_vec = {0, 2, 3, 4, 1}; std::vector<int64_t> ind_vec = {0, 2, 3, 4, 1};
...@@ -44,10 +45,10 @@ struct test_roialign : verify_program<test_roialign> ...@@ -44,10 +45,10 @@ struct test_roialign : verify_program<test_roialign>
auto roi = mm->add_parameter("roi", roi_s); auto roi = mm->add_parameter("roi", roi_s);
auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec)); auto ind = mm->add_literal(migraphx::literal(ind_s, ind_vec));
auto r = mm->add_instruction(migraphx::make_op("roialign", auto r = mm->add_instruction(migraphx::make_op("roialign",
{{"spatial_scale", 1.0}, {{"spatial_scale", 1.0},
{"output_height", 5}, {"output_height", 5},
{"output_width", 5}, {"output_width", 5},
{"sampling_ratio", 2}}), {"sampling_ratio", 2}}),
x, x,
roi, roi,
ind); ind);
...@@ -56,3 +57,7 @@ struct test_roialign : verify_program<test_roialign> ...@@ -56,3 +57,7 @@ struct test_roialign : verify_program<test_roialign>
return p; return p;
} }
}; };
template struct test_roialign<migraphx::shape::float_type>;
// template struct test_roialign<migraphx::shape::half_type>;
// template struct test_roialign<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -47,3 +47,5 @@ struct test_rsqrt : verify_program<test_rsqrt> ...@@ -47,3 +47,5 @@ struct test_rsqrt : verify_program<test_rsqrt>
return p; return p;
}; };
}; };
// TOOD : Add FP8 test
\ No newline at end of file
...@@ -21,22 +21,23 @@ ...@@ -21,22 +21,23 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/shape.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_scatternd : verify_program<test_scatternd> template <migraphx::shape::type_t DType>
struct test_scatternd : verify_program<test_scatternd<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto dtype = migraphx::shape::float_type;
auto itype = migraphx::shape::int64_type; auto itype = migraphx::shape::int64_type;
migraphx::shape ds{dtype, {1}}; migraphx::shape ds{DType, {1}};
migraphx::shape is{itype, {4, 1}}; migraphx::shape is{itype, {4, 1}};
migraphx::shape us{dtype, {4}}; migraphx::shape us{DType, {4}};
std::vector<int64_t> ind_vec{4, 3, 1, 7}; std::vector<int64_t> ind_vec{4, 3, 1, 7};
auto ld = mm->add_literal(migraphx::literal{ds, {1}}); auto ld = mm->add_literal(migraphx::literal{ds, {1}});
...@@ -51,3 +52,7 @@ struct test_scatternd : verify_program<test_scatternd> ...@@ -51,3 +52,7 @@ struct test_scatternd : verify_program<test_scatternd>
return p; return p;
} }
}; };
template struct test_scatternd<migraphx::shape::float_type>;
template struct test_scatternd<migraphx::shape::half_type>;
template struct test_scatternd<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_sin : verify_program<test_sin> template <migraphx::shape::type_t DType>
struct test_sin : verify_program<test_sin<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {10}}; migraphx::shape s{DType, {10}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("sin"), x); mm->add_instruction(migraphx::make_op("sin"), x);
return p; return p;
} }
}; };
template struct test_sin<migraphx::shape::float_type>;
template struct test_sin<migraphx::shape::half_type>;
template struct test_sin<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_sinh : verify_program<test_sinh> template <migraphx::shape::type_t DType>
struct test_sinh : verify_program<test_sinh<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("sinh"), x); mm->add_instruction(migraphx::make_op("sinh"), x);
return p; return p;
} }
}; };
template struct test_sinh<migraphx::shape::float_type>;
template struct test_sinh<migraphx::shape::half_type>;
template struct test_sinh<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -48,3 +48,7 @@ template struct test_softmax<0, migraphx::shape::half_type>; ...@@ -48,3 +48,7 @@ template struct test_softmax<0, migraphx::shape::half_type>;
template struct test_softmax<1, migraphx::shape::half_type>; template struct test_softmax<1, migraphx::shape::half_type>;
template struct test_softmax<2, migraphx::shape::half_type>; template struct test_softmax<2, migraphx::shape::half_type>;
template struct test_softmax<3, migraphx::shape::half_type>; template struct test_softmax<3, migraphx::shape::half_type>;
// template struct test_softmax<0, migraphx::shape::fp8e4m3fnuz_type>;
// template struct test_softmax<1, migraphx::shape::fp8e4m3fnuz_type>;
// template struct test_softmax<2, migraphx::shape::fp8e4m3fnuz_type>;
// template struct test_softmax<3, migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,16 +27,21 @@ ...@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_sqrt : verify_program<test_sqrt> template <migraphx::shape::type_t DType>
struct test_sqrt : verify_program<test_sqrt<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; migraphx::shape s{DType, {2, 3, 4, 6}};
auto param = mm->add_parameter("x", s); auto param = mm->add_parameter("x", s);
auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param); auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param);
mm->add_instruction(migraphx::make_op("sqrt"), param_abs); mm->add_instruction(migraphx::make_op("sqrt"), param_abs);
return p; return p;
} }
}; };
template struct test_sqrt<migraphx::shape::float_type>;
template struct test_sqrt<migraphx::shape::half_type>;
template struct test_sqrt<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_tan : verify_program<test_tan> template <migraphx::shape::type_t DType>
struct test_tan : verify_program<test_tan<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("tan"), x); mm->add_instruction(migraphx::make_op("tan"), x);
return p; return p;
} }
}; };
template struct test_tan<migraphx::shape::float_type>;
template struct test_tan<migraphx::shape::half_type>;
template struct test_tan<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,19 @@ ...@@ -27,14 +27,19 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_tanh : verify_program<test_tanh> template <migraphx::shape::type_t DType>
struct test_tanh : verify_program<test_tanh<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto x = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("tanh"), x); mm->add_instruction(migraphx::make_op("tanh"), x);
return p; return p;
} }
}; };
template struct test_tanh<migraphx::shape::float_type>;
template struct test_tanh<migraphx::shape::half_type>;
template struct test_tanh<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,7 +27,8 @@ ...@@ -27,7 +27,8 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_where : verify_program<test_where> template <migraphx::shape::type_t DType>
struct test_where : verify_program<test_where<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -44,3 +45,7 @@ struct test_where : verify_program<test_where> ...@@ -44,3 +45,7 @@ struct test_where : verify_program<test_where>
return p; return p;
}; };
}; };
template struct test_where<migraphx::shape::float_type>;
template struct test_where<migraphx::shape::half_type>;
template struct test_where<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment