Commit 8d7a8a6c authored by Artur Wojcik's avatar Artur Wojcik
Browse files

Merge branch 'develop' into uif2-initial

parents 25b33431 a09dc502
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_atan : verify_program<test_atan> template <migraphx::shape::type_t DType>
struct test_atan : verify_program<test_atan<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("atan"), x); mm->add_instruction(migraphx::make_op("atan"), x);
return p; return p;
} }
}; };
template struct test_atan<migraphx::shape::float_type>;
template struct test_atan<migraphx::shape::half_type>;
template struct test_atan<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -23,20 +23,24 @@ ...@@ -23,20 +23,24 @@
*/ */
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_atanh : verify_program<test_atanh> template <typename CType>
struct test_atanh : verify_program<test_atanh<CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape::type_t dtype = migraphx::shape::get_type<CType>();
migraphx::shape s{dtype, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto min_val = mm->add_literal(-0.95f); auto min_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {-0.95f}});
auto max_val = mm->add_literal(0.95f); auto max_val = mm->add_literal(migraphx::literal{migraphx::shape{dtype}, {0.95f}});
min_val = min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {16}}}), min_val);
max_val = max_val =
...@@ -46,3 +50,7 @@ struct test_atanh : verify_program<test_atanh> ...@@ -46,3 +50,7 @@ struct test_atanh : verify_program<test_atanh>
return p; return p;
} }
}; };
template struct test_atanh<float>;
template struct test_atanh<migraphx::half>;
template struct test_atanh<migraphx::fp8::fp8e4m3fnuz>;
...@@ -27,16 +27,21 @@ ...@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_ceil : verify_program<test_ceil> template <migraphx::shape::type_t DType>
struct test_ceil : verify_program<test_ceil<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::double_type, {2, 3, 4, 6}}; migraphx::shape s{DType, {2, 3, 4, 6}};
auto param = mm->add_parameter("x", s); auto param = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("ceil"), param); mm->add_instruction(migraphx::make_op("ceil"), param);
return p; return p;
}; };
}; };
template struct test_ceil<migraphx::shape::float_type>;
template struct test_ceil<migraphx::shape::half_type>;
template struct test_ceil<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,16 +27,17 @@ ...@@ -27,16 +27,17 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_concat_axis_0 : verify_program<test_concat_axis_0> template <migraphx::shape::type_t DType>
struct test_concat_axis_0 : verify_program<test_concat_axis_0<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
int axis = 0; int axis = 0;
migraphx::shape s0{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s0{DType, {2, 2}};
migraphx::shape s1{migraphx::shape::int32_type, {3, 2}}; migraphx::shape s1{DType, {3, 2}};
migraphx::shape s2{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s2{DType, {1, 2}};
auto l0 = mm->add_parameter("x", s0); auto l0 = mm->add_parameter("x", s0);
auto l1 = mm->add_parameter("y", s1); auto l1 = mm->add_parameter("y", s1);
auto l2 = mm->add_parameter("z", s2); auto l2 = mm->add_parameter("z", s2);
...@@ -44,3 +45,8 @@ struct test_concat_axis_0 : verify_program<test_concat_axis_0> ...@@ -44,3 +45,8 @@ struct test_concat_axis_0 : verify_program<test_concat_axis_0>
return p; return p;
} }
}; };
template struct test_concat_axis_0<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_concat_axis_0<migraphx::shape::half_type>;
template struct test_concat_axis_0<migraphx::shape::float_type>;
template struct test_concat_axis_0<migraphx::shape::int32_type>;
...@@ -29,16 +29,20 @@ ...@@ -29,16 +29,20 @@
#include <cassert> #include <cassert>
struct test_contiguous : verify_program<test_contiguous> template <migraphx::shape::type_t DType>
struct test_contiguous : verify_program<test_contiguous<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}}; migraphx::shape s{DType, {4, 4, 4, 3}, {48, 4, 1, 16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("contiguous"), x); mm->add_instruction(migraphx::make_op("contiguous"), x);
assert(p.get_output_shapes().back().standard()); assert(p.get_output_shapes().back().standard());
return p; return p;
} }
}; };
template struct test_contiguous<migraphx::shape::float_type>;
template struct test_contiguous<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -29,26 +29,26 @@ ...@@ -29,26 +29,26 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_convert : verify_program<test_convert> template <migraphx::shape::type_t From, migraphx::shape::type_t To>
struct test_convert : verify_program<test_convert<From, To>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::int8_type, {8, 24}}; migraphx::shape sa{From, {8, 24}};
migraphx::shape sb{migraphx::shape::int8_type, {24, 6}}; migraphx::shape sb{From, {24, 6}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto ia = mm->add_instruction( auto ia = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert", {{"target_type", migraphx::to_value(To)}}), pa);
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pa);
auto ib = mm->add_instruction( auto ib = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert", {{"target_type", migraphx::to_value(To)}}), pb);
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
pb);
mm->add_instruction(migraphx::make_op("dot"), ia, ib); mm->add_instruction(migraphx::make_op("dot"), ia, ib);
return p; return p;
}; };
}; };
template struct test_convert<migraphx::shape::int8_type, migraphx::shape::float_type>;
template struct test_convert<migraphx::shape::fp8e4m3fnuz_type, migraphx::shape::float_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_cos : verify_program<test_cos> template <migraphx::shape::type_t DType>
struct test_cos : verify_program<test_cos<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {8}}; migraphx::shape s{DType, {8}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("cos"), x); mm->add_instruction(migraphx::make_op("cos"), x);
return p; return p;
} }
}; };
template struct test_cos<migraphx::shape::float_type>;
template struct test_cos<migraphx::shape::half_type>;
template struct test_cos<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_cosh : verify_program<test_cosh> template <migraphx::shape::type_t DType>
struct test_cosh : verify_program<test_cosh<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("cosh"), x); mm->add_instruction(migraphx::make_op("cosh"), x);
return p; return p;
} }
}; };
template struct test_cosh<migraphx::shape::float_type>;
template struct test_cosh<migraphx::shape::half_type>;
template struct test_cosh<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_erf : verify_program<test_erf> template <migraphx::shape::type_t DType>
struct test_erf : verify_program<test_erf<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; migraphx::shape s{DType, {2, 3, 4, 6}};
auto param = mm->add_parameter("x", s); auto param = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("erf"), param); mm->add_instruction(migraphx::make_op("erf"), param);
return p; return p;
} }
}; };
template struct test_erf<migraphx::shape::float_type>;
template struct test_erf<migraphx::shape::half_type>;
template struct test_erf<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_exp : verify_program<test_exp> template <migraphx::shape::type_t DType>
struct test_exp : verify_program<test_exp<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {6}}; migraphx::shape s{DType, {6}};
auto x = mm->add_instruction(migraphx::make_op("abs"), mm->add_parameter("x", s)); auto x = mm->add_instruction(migraphx::make_op("abs"), mm->add_parameter("x", s));
mm->add_instruction(migraphx::make_op("exp"), x); mm->add_instruction(migraphx::make_op("exp"), x);
return p; return p;
} }
}; };
template struct test_exp<migraphx::shape::float_type>;
template struct test_exp<migraphx::shape::half_type>;
template struct test_exp<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,16 +27,21 @@ ...@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_floor : verify_program<test_floor> template <migraphx::shape::type_t DType>
struct test_floor : verify_program<test_floor<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; migraphx::shape s{DType, {2, 3, 4, 6}};
auto param = mm->add_parameter("x", s); auto param = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("floor"), param); mm->add_instruction(migraphx::make_op("floor"), param);
return p; return p;
}; };
}; };
template struct test_floor<migraphx::shape::float_type>;
template struct test_floor<migraphx::shape::half_type>;
template struct test_floor<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -34,40 +34,52 @@ ...@@ -34,40 +34,52 @@
Adding this because HIP fmod sign changes when y = 0 resulting in nan and -nan not beign Adding this because HIP fmod sign changes when y = 0 resulting in nan and -nan not beign
consistent between ref and gpu implementations. consistent between ref and gpu implementations.
*/ */
migraphx::instruction_ref add_epsilon(migraphx::module& m, migraphx::instruction_ref y) migraphx::instruction_ref add_epsilon(migraphx::module& m,
migraphx::instruction_ref y,
migraphx::shape::type_t dtype = migraphx::shape::float_type)
{ {
auto zero = m.add_literal(0.0f); auto zero = m.add_literal(migraphx::literal{migraphx::shape{dtype}, {0.0f}});
auto eps = m.add_literal(1e-3f); auto eps = m.add_literal(migraphx::literal{migraphx::shape{dtype}, {1e-3f}});
auto op_y = add_common_op(m, migraphx::make_op("equal"), {y, zero}); auto op_y = add_common_op(m, migraphx::make_op("equal"), {y, zero});
return add_common_op(m, migraphx::make_op("where"), {op_y, eps, y}); return add_common_op(m, migraphx::make_op("where"), {op_y, eps, y});
} }
struct test_fmod : verify_program<test_fmod> template <migraphx::shape::type_t DType>
struct test_fmod : verify_program<test_fmod<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {64}}; migraphx::shape s{DType, {64}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto op_where = add_epsilon(*mm, y); auto op_where = add_epsilon(*mm, y, DType);
mm->add_instruction(migraphx::make_op("fmod"), x, op_where); mm->add_instruction(migraphx::make_op("fmod"), x, op_where);
return p; return p;
} }
}; };
template struct test_fmod<migraphx::shape::float_type>;
template struct test_fmod<migraphx::shape::half_type>;
template struct test_fmod<migraphx::shape::fp8e4m3fnuz_type>;
struct test_mod : verify_program<test_mod> template <migraphx::shape::type_t DType>
struct test_mod : verify_program<test_mod<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {64}}; migraphx::shape s{DType, {64}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto op_where = add_epsilon(*mm, y); auto op_where = add_epsilon(*mm, y, DType);
mm->add_instruction(migraphx::make_op("mod"), x, op_where); mm->add_instruction(migraphx::make_op("mod"), x, op_where);
return p; return p;
} }
}; };
template struct test_mod<migraphx::shape::float_type>;
// TODO: Fix half type test
// template struct test_mod<migraphx::shape::half_type>;
template struct test_mod<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,14 @@ ...@@ -27,14 +27,14 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <int Axis> template <int Axis, migraphx::shape::type_t DType>
struct test_gather : verify_program<test_gather<Axis>> struct test_gather : verify_program<test_gather<Axis, DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 3}}; migraphx::shape s{DType, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}}; migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{1, 2, 2, 1}; std::vector<int> indices{1, 2, 2, 1};
auto a0 = mm->add_parameter("data", s); auto a0 = mm->add_parameter("data", s);
...@@ -46,6 +46,10 @@ struct test_gather : verify_program<test_gather<Axis>> ...@@ -46,6 +46,10 @@ struct test_gather : verify_program<test_gather<Axis>>
}; };
// Standard gather test // Standard gather test
template struct test_gather<0>; template struct test_gather<0, migraphx::shape::float_type>;
template struct test_gather<0, migraphx::shape::half_type>;
template struct test_gather<0, migraphx::shape::fp8e4m3fnuz_type>;
// Test Negative axis // Test Negative axis
template struct test_gather<-2>; template struct test_gather<-2, migraphx::shape::float_type>;
template struct test_gather<-2, migraphx::shape::half_type>;
template struct test_gather<-2, migraphx::shape::fp8e4m3fnuz_type>;
...@@ -26,13 +26,14 @@ ...@@ -26,13 +26,14 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_gathernd_default : verify_program<test_gathernd_default> template <migraphx::shape::type_t DType>
struct test_gathernd_default : verify_program<test_gathernd_default<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape ds{migraphx::shape::float_type, {2, 2}}; migraphx::shape ds{DType, {2, 2}};
migraphx::shape is{migraphx::shape::int64_type, {2, 2}}; migraphx::shape is{migraphx::shape::int64_type, {2, 2}};
std::vector<int64_t> indices{0, 0, 1, 1}; std::vector<int64_t> indices{0, 0, 1, 1};
auto a0 = mm->add_parameter("data", ds); auto a0 = mm->add_parameter("data", ds);
...@@ -41,3 +42,7 @@ struct test_gathernd_default : verify_program<test_gathernd_default> ...@@ -41,3 +42,7 @@ struct test_gathernd_default : verify_program<test_gathernd_default>
return p; return p;
} }
}; };
template struct test_gathernd_default<migraphx::shape::float_type>;
template struct test_gathernd_default<migraphx::shape::half_type>;
template struct test_gathernd_default<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -26,16 +26,20 @@ ...@@ -26,16 +26,20 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <migraphx::shape::type_t DType>
struct test_gemm : verify_program<test_gemm> struct test_gemm : verify_program<test_gemm<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}}); auto a = mm->add_parameter("a", migraphx::shape{DType, {4, 5}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}}); auto b = mm->add_parameter("b", migraphx::shape{DType, {5, 3}});
mm->add_instruction(migraphx::make_op("dot"), a, b); mm->add_instruction(migraphx::make_op("dot"), a, b);
return p; return p;
} }
}; };
template struct test_gemm<migraphx::shape::float_type>;
template struct test_gemm<migraphx::shape::half_type>;
template struct test_gemm<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -28,15 +28,16 @@ ...@@ -28,15 +28,16 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_gemm_copy : verify_program<test_gemm_copy> template <migraphx::shape::type_t DType>
struct test_gemm_copy : verify_program<test_gemm_copy<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; migraphx::shape sa{DType, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}}; migraphx::shape sb{DType, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {1, 8}}; migraphx::shape sc{DType, {1, 8}};
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto pc = mm->add_parameter("c", sc); auto pc = mm->add_parameter("c", sc);
...@@ -46,3 +47,7 @@ struct test_gemm_copy : verify_program<test_gemm_copy> ...@@ -46,3 +47,7 @@ struct test_gemm_copy : verify_program<test_gemm_copy>
return p; return p;
} }
}; };
template struct test_gemm_copy<migraphx::shape::float_type>;
template struct test_gemm_copy<migraphx::shape::half_type>;
template struct test_gemm_copy<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,19 @@ ...@@ -27,15 +27,19 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_gemm_ex : verify_program<test_gemm_ex> template <migraphx::shape::type_t DType>
struct test_gemm_ex : verify_program<test_gemm_ex<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 4, 5}}); auto a = mm->add_parameter("a", migraphx::shape{DType, {1, 1, 4, 5}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}}); auto b = mm->add_parameter("b", migraphx::shape{DType, {1, 1, 5, 3}});
mm->add_instruction(migraphx::make_op("dot"), a, b); mm->add_instruction(migraphx::make_op("dot"), a, b);
return p; return p;
} }
}; };
template struct test_gemm_ex<migraphx::shape::float_type>;
template struct test_gemm_ex<migraphx::shape::half_type>;
template struct test_gemm_ex<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,16 +27,21 @@ ...@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_gemm_transposea : verify_program<test_gemm_transposea> template <migraphx::shape::type_t DType>
struct test_gemm_transposea : verify_program<test_gemm_transposea<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); auto a = mm->add_parameter("a", migraphx::shape{DType, {5, 4}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {5, 3}}); auto b = mm->add_parameter("b", migraphx::shape{DType, {5, 3}});
auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a); auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a);
mm->add_instruction(migraphx::make_op("dot"), at, b); mm->add_instruction(migraphx::make_op("dot"), at, b);
return p; return p;
} }
}; };
template struct test_gemm_transposea<migraphx::shape::float_type>;
template struct test_gemm_transposea<migraphx::shape::half_type>;
template struct test_gemm_transposea<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,17 +27,22 @@ ...@@ -27,17 +27,22 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex> template <migraphx::shape::type_t DType>
struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); auto a = mm->add_parameter("a", migraphx::shape{DType, {1, 1, 5, 4}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}}); auto b = mm->add_parameter("b", migraphx::shape{DType, {1, 1, 5, 3}});
auto at = auto at =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), a); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), a);
mm->add_instruction(migraphx::make_op("dot"), at, b); mm->add_instruction(migraphx::make_op("dot"), at, b);
return p; return p;
} }
}; };
template struct test_gemm_transposea_ex<migraphx::shape::float_type>;
template struct test_gemm_transposea_ex<migraphx::shape::half_type>;
template struct test_gemm_transposea_ex<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,17 +27,22 @@ ...@@ -27,17 +27,22 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_gemm_transposeab : verify_program<test_gemm_transposeab> template <migraphx::shape::type_t DType>
struct test_gemm_transposeab : verify_program<test_gemm_transposeab<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); auto a = mm->add_parameter("a", migraphx::shape{DType, {5, 4}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}}); auto b = mm->add_parameter("b", migraphx::shape{DType, {3, 5}});
auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a); auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a);
auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b);
mm->add_instruction(migraphx::make_op("dot"), at, bt); mm->add_instruction(migraphx::make_op("dot"), at, bt);
return p; return p;
} }
}; };
template struct test_gemm_transposeab<migraphx::shape::float_type>;
template struct test_gemm_transposeab<migraphx::shape::half_type>;
template struct test_gemm_transposeab<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment