Unverified Commit 6d0b6bcf authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Add FP8 rocblas gemm support (#2473)

parent e3e00547
...@@ -27,17 +27,22 @@ ...@@ -27,17 +27,22 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex> template <migraphx::shape::type_t DType>
struct test_gemm_transposea_ex : verify_program<test_gemm_transposea_ex<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 4}}); auto a = mm->add_parameter("a", migraphx::shape{DType, {1, 1, 5, 4}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 3}}); auto b = mm->add_parameter("b", migraphx::shape{DType, {1, 1, 5, 3}});
auto at = auto at =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), a); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), a);
mm->add_instruction(migraphx::make_op("dot"), at, b); mm->add_instruction(migraphx::make_op("dot"), at, b);
return p; return p;
} }
}; };
template struct test_gemm_transposea_ex<migraphx::shape::float_type>;
template struct test_gemm_transposea_ex<migraphx::shape::half_type>;
template struct test_gemm_transposea_ex<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,17 +27,22 @@ ...@@ -27,17 +27,22 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_gemm_transposeab : verify_program<test_gemm_transposeab> template <migraphx::shape::type_t DType>
struct test_gemm_transposeab : verify_program<test_gemm_transposeab<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {5, 4}}); auto a = mm->add_parameter("a", migraphx::shape{DType, {5, 4}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}}); auto b = mm->add_parameter("b", migraphx::shape{DType, {3, 5}});
auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a); auto at = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a);
auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b);
mm->add_instruction(migraphx::make_op("dot"), at, bt); mm->add_instruction(migraphx::make_op("dot"), at, bt);
return p; return p;
} }
}; };
template struct test_gemm_transposeab<migraphx::shape::float_type>;
template struct test_gemm_transposeab<migraphx::shape::half_type>;
template struct test_gemm_transposeab<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,16 +27,21 @@ ...@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_gemm_transposeb : verify_program<test_gemm_transposeb> template <migraphx::shape::type_t DType>
struct test_gemm_transposeb : verify_program<test_gemm_transposeb<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {4, 5}}); auto a = mm->add_parameter("a", migraphx::shape{DType, {4, 5}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {3, 5}}); auto b = mm->add_parameter("b", migraphx::shape{DType, {3, 5}});
auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); auto bt = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b);
mm->add_instruction(migraphx::make_op("dot"), a, bt); mm->add_instruction(migraphx::make_op("dot"), a, bt);
return p; return p;
} }
}; };
template struct test_gemm_transposeb<migraphx::shape::float_type>;
template struct test_gemm_transposeb<migraphx::shape::half_type>;
template struct test_gemm_transposeb<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,17 +27,22 @@ ...@@ -27,17 +27,22 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex> template <migraphx::shape::type_t DType>
struct test_gemm_transposeb_ex : verify_program<test_gemm_transposeb_ex<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {1, 4, 5}}); auto a = mm->add_parameter("a", migraphx::shape{DType, {1, 4, 5}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {1, 3, 5}}); auto b = mm->add_parameter("b", migraphx::shape{DType, {1, 3, 5}});
auto bt = auto bt =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b);
mm->add_instruction(migraphx::make_op("dot"), a, bt); mm->add_instruction(migraphx::make_op("dot"), a, bt);
return p; return p;
} }
}; };
template struct test_gemm_transposeb_ex<migraphx::shape::float_type>;
template struct test_gemm_transposeb_ex<migraphx::shape::half_type>;
template struct test_gemm_transposeb_ex<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,17 +27,17 @@ ...@@ -27,17 +27,17 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_mul_dot_a : verify_program<test_mul_dot_a> template <migraphx::shape::type_t DType>
struct test_mul_dot_a : verify_program<test_mul_dot_a<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}}; migraphx::shape as{DType, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}}; migraphx::shape bs{DType, {2, 32, 128}};
auto a = mm->add_parameter("input", as); auto a = mm->add_parameter("input", as);
auto lit = auto lit = mm->add_literal(migraphx::generate_literal({DType, {1, 1, 32}}));
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 1, 32}}));
auto litb = mm->add_instruction( auto litb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", as.lens()}}), lit); migraphx::make_op("multibroadcast", {{"out_lens", as.lens()}}), lit);
auto mul = mm->add_instruction(migraphx::make_op("mul"), a, litb); auto mul = mm->add_instruction(migraphx::make_op("mul"), a, litb);
...@@ -47,3 +47,7 @@ struct test_mul_dot_a : verify_program<test_mul_dot_a> ...@@ -47,3 +47,7 @@ struct test_mul_dot_a : verify_program<test_mul_dot_a>
return p; return p;
} }
}; };
template struct test_mul_dot_a<migraphx::shape::float_type>;
template struct test_mul_dot_a<migraphx::shape::half_type>;
template struct test_mul_dot_a<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,17 +27,18 @@ ...@@ -27,17 +27,18 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_mul_dot_b : verify_program<test_mul_dot_b> template <migraphx::shape::type_t DType>
struct test_mul_dot_b : verify_program<test_mul_dot_b<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}}; migraphx::shape as{DType, {2, 256, 32}};
migraphx::shape bs{migraphx::shape::float_type, {2, 32, 128}}; migraphx::shape bs{DType, {2, 32, 128}};
auto b = mm->add_parameter("input", bs); auto b = mm->add_parameter("input", bs);
auto lit = auto lit = mm->add_literal(migraphx::generate_literal({DType, {1, 32, 1}}));
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {1, 32, 1}}));
auto litb = mm->add_instruction( auto litb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", bs.lens()}}), lit); migraphx::make_op("multibroadcast", {{"out_lens", bs.lens()}}), lit);
auto mul = mm->add_instruction(migraphx::make_op("mul"), b, litb); auto mul = mm->add_instruction(migraphx::make_op("mul"), b, litb);
...@@ -47,3 +48,7 @@ struct test_mul_dot_b : verify_program<test_mul_dot_b> ...@@ -47,3 +48,7 @@ struct test_mul_dot_b : verify_program<test_mul_dot_b>
return p; return p;
} }
}; };
template struct test_mul_dot_b<migraphx::shape::float_type>;
template struct test_mul_dot_b<migraphx::shape::half_type>;
template struct test_mul_dot_b<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,17 @@ ...@@ -27,15 +27,17 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1>
template <migraphx::shape::type_t DType>
struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 32, 64}}; migraphx::shape m1_shape{DType, {2, 32, 64}};
migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}}; migraphx::shape m2_shape{DType, {64, 64}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 32, 192}}; migraphx::shape m3_shape{DType, {2, 32, 192}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape)); auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape));
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}), l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 64, 64}}}),
...@@ -56,3 +58,7 @@ struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1> ...@@ -56,3 +58,7 @@ struct test_unbatched_gemm_1 : verify_program<test_unbatched_gemm_1>
return p; return p;
} }
}; };
template struct test_unbatched_gemm_1<migraphx::shape::float_type>;
template struct test_unbatched_gemm_1<migraphx::shape::half_type>;
template struct test_unbatched_gemm_1<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,16 @@ ...@@ -27,14 +27,16 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2>
template <migraphx::shape::type_t DType>
struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {4, 32, 64}}; migraphx::shape m1_shape{DType, {4, 32, 64}};
migraphx::shape m2_shape{migraphx::shape::float_type, {64, 64}}; migraphx::shape m2_shape{DType, {64, 64}};
auto l1 = mm->add_parameter("1", m1_shape); auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape)); auto l2 = mm->add_literal(migraphx::generate_literal(m2_shape));
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 64, 64}}}), l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 64, 64}}}),
...@@ -44,3 +46,7 @@ struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2> ...@@ -44,3 +46,7 @@ struct test_unbatched_gemm_2 : verify_program<test_unbatched_gemm_2>
return p; return p;
} }
}; };
template struct test_unbatched_gemm_2<migraphx::shape::float_type>;
template struct test_unbatched_gemm_2<migraphx::shape::half_type>;
template struct test_unbatched_gemm_2<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment