Commit 4e07dfcc authored by Umang Yadav's avatar Umang Yadav
Browse files

revert some changes

parent 050184cb
...@@ -24,23 +24,19 @@ ...@@ -24,23 +24,19 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <typename DType, typename CType> struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto dtype = migraphx::shape::get_type<DType>{}; migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}};
auto ctype = migraphx::shape::get_type<CType>{}; migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}};
migraphx::shape m1_shape{dtype, {3, 2, 8, 2}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
migraphx::shape m2_shape{dtype, {3, 2, 7, 8}};
migraphx::shape m3_shape{ctype, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction( auto tl1 = mm->add_instruction(
...@@ -49,11 +45,7 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1<DType, CType>> ...@@ -49,11 +45,7 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1<DType, CType>>
auto tl2 = mm->add_instruction( auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta( migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2);
*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), CType{3}, CType{2});
return p; return p;
} }
}; };
template struct batch_quant_dot_1<int8_t, int32_t>;
template struct batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -28,16 +28,15 @@ ...@@ -28,16 +28,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <migraphx::shape::type_t DType, migraphx::shape::type_t CType> struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{DType, {3, 2, 2, 8}}; migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}};
migraphx::shape m2_shape{DType, {3, 2, 8, 7}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}};
migraphx::shape m3_shape{CType, {3, 2, 2, 7}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
...@@ -46,5 +45,3 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2<DType, CType>> ...@@ -46,5 +45,3 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2<DType, CType>>
return p; return p;
} }
}; };
template struct batch_quant_dot_2<migraphx::shape::int8_type, migraphx::shape::int32_type>;
template struct batch_quant_dot_2<migraphx::shape::fp8e4m3fnuz_type, migraphx::shape::float_type>;
...@@ -27,15 +27,14 @@ ...@@ -27,15 +27,14 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <migraphx::shape::type_t DType> struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
struct batch_quant_dot_3 : verify_program<batch_quant_dot_3<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{DType, {3, 2, 2, 6}}; migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 6}};
migraphx::shape m2_shape{DType, {3, 2, 6, 7}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 6, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
...@@ -43,5 +42,3 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3<DType>> ...@@ -43,5 +42,3 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3<DType>>
return p; return p;
} }
}; };
template struct batch_quant_dot_3<migraphx::shape::int8_type>;
template struct batch_quant_dot_3<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,14 @@ ...@@ -27,15 +27,14 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <migraphx::shape::type_t DType> struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
struct batch_quant_dot_4 : verify_program<batch_quant_dot_4<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{DType, {2, 4, 6, 3}}; migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}};
migraphx::shape m2_shape{DType, {7, 2, 6, 3}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
...@@ -47,5 +46,3 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4<DType>> ...@@ -47,5 +46,3 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4<DType>>
return p; return p;
} }
}; };
template struct batch_quant_dot_4<migraphx::shape::int8_type>;
template struct batch_quant_dot_4<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,14 @@ ...@@ -27,15 +27,14 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <migraphx::shape::type_t DType> struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
struct batch_quant_dot_5 : verify_program<batch_quant_dot_5<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{DType, {3, 2, 7, 2}}; migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}};
migraphx::shape m2_shape{DType, {3, 2, 5, 7}}; migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
...@@ -49,5 +48,3 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5<DType>> ...@@ -49,5 +48,3 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5<DType>>
return p; return p;
} }
}; };
template struct batch_quant_dot_5<migraphx::shape::int8_type>;
template struct batch_quant_dot_5<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -25,31 +25,23 @@ ...@@ -25,31 +25,23 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <typename DType, typename CType> struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto ctype = migraphx::shape::get_type<CType>(); migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
auto dtype = migraphx::shape::get_type<DType>(); migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m1_shape{dtype, {2, 8}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
migraphx::shape m2_shape{dtype, {8, 7}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta( migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1);
*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{1});
return p; return p;
} }
}; };
template struct quant_dot_3args_1<int8_t, int32_t>;
template struct quant_dot_3args_1<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -28,29 +28,22 @@ ...@@ -28,29 +28,22 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <typename DType, typename CType> struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto ctype = migraphx::shape::get_type<CType>(); migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
auto dtype = migraphx::shape::get_type<DType>(); migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m1_shape{dtype, {8, 2}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
migraphx::shape m2_shape{dtype, {8, 7}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta( migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{3});
return p; return p;
} }
}; };
template struct quant_dot_3args_2<int8_t, int32_t>;
template struct quant_dot_3args_2<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -28,28 +28,22 @@ ...@@ -28,28 +28,22 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <typename DType, typename CType> struct quant_dot_3args_3 : verify_program<quant_dot_3args_3>
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto ctype = migraphx::shape::get_type<CType>(); migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
auto dtype = migraphx::shape::get_type<DType>(); migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m1_shape{dtype, {2, 8}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
migraphx::shape m2_shape{dtype, {7, 8}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta( migraphx::add_apply_alpha_beta(*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3);
*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), CType{2}, CType{3});
return p; return p;
} }
}; };
template struct quant_dot_3args_3<int8_t, int32_t>;
template struct quant_dot_3args_3<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -28,18 +28,15 @@ ...@@ -28,18 +28,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <typename DType, typename CType> struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
struct quant_dot_3args_4 : verify_program<quant_dot_3args_4<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto ctype = migraphx::shape::get_type<CType>(); migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
auto dtype = migraphx::shape::get_type<DType>(); migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m1_shape{dtype, {8, 2}}; migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
migraphx::shape m2_shape{dtype, {7, 8}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = auto tl1 =
...@@ -48,11 +45,7 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4<DType, CType>> ...@@ -48,11 +45,7 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4<DType, CType>>
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta( migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2);
*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), CType{3}, CType{2});
return p; return p;
} }
}; };
template struct quant_dot_3args_4<int8_t, int32_t>;
template struct quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -28,17 +28,14 @@ ...@@ -28,17 +28,14 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <typename DType, typename CType> struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
struct quant_dot_3args_5 : verify_program<quant_dot_3args_5<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto dtype = migraphx::shape::get_type<DType>(); migraphx::shape m1_shape{migraphx::shape::int8_type, {6, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 6}};
migraphx::shape m1_shape{dtype, {6, 2}};
migraphx::shape m2_shape{dtype, {7, 6}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = auto tl1 =
...@@ -46,10 +43,7 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5<DType, CType>> ...@@ -46,10 +43,7 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5<DType, CType>>
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), CType{3}); migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), 3);
return p; return p;
} }
}; };
template struct quant_dot_3args_5<int8_t, int32_t>;
template struct quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment