Commit d5fa82db authored by Umang Yadav's avatar Umang Yadav
Browse files

add quant_dot support for fp8

parent 7e80f627
...@@ -44,9 +44,10 @@ struct quant_dot ...@@ -44,9 +44,10 @@ struct quant_dot
const shape& a = inputs.at(0); const shape& a = inputs.at(0);
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
if(t != shape::int8_type) std::set<migraphx::shape::type_t> suppported_types = {shape::int8_type, shape::fp8e4m3fnuz_type};
if(not contains(suppported_types, t))
{ {
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t"); MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t and fp8e4m3fnuz_type");
} }
if(not std::all_of( if(not std::all_of(
...@@ -73,6 +74,10 @@ struct quant_dot ...@@ -73,6 +74,10 @@ struct quant_dot
auto out_lens = a.lens(); auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1]; out_lens[dim_1] = b.lens()[dim_1];
if(t == shape::fp8e4m3fnuz_type)
{
return {shape::float_type, out_lens};
} // else int8 gemm
return {shape::int32_type, out_lens}; return {shape::int32_type, out_lens};
} }
}; };
......
...@@ -183,6 +183,11 @@ struct find_nested_convert ...@@ -183,6 +183,11 @@ struct find_nested_convert
auto x = ins->inputs().front(); auto x = ins->inputs().front();
auto input = x->inputs().front(); auto input = x->inputs().front();
while(input->name() == "convert")
{
input = input->inputs().front();
}
if(ins->get_shape() != input->get_shape()) if(ins->get_shape() != input->get_shape())
return; return;
......
...@@ -112,7 +112,7 @@ struct rocblas_gemm ...@@ -112,7 +112,7 @@ struct rocblas_gemm
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
if(this->name() == "gpu::gemm") if(this->name() == "gpu::gemm" or output_shape.type() == migraphx::shape::float_type)
{ {
gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx); gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
} }
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <migraphx/ref/lowering.hpp> #include <migraphx/ref/lowering.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
...@@ -307,19 +308,46 @@ struct ref_quant_gemm ...@@ -307,19 +308,46 @@ struct ref_quant_gemm
{ {
argument result{output_shape}; argument result{output_shape};
// first, convert the args[0] and args[1] from int8_t to int32_t // first, convert the args[0] and args[1] from int8_t to int32_t
argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}}; argument arg_0{{output_shape.type(), {args.at(0).get_shape().lens()}}};
argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}}; argument arg_1{{output_shape.type(), {args.at(1).get_shape().lens()}}};
arg_0.visit([&](auto output) { if(output_shape.type() == migraphx::shape::float_type)
args.at(0).visit( {
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); }); arg_0.visit([&](auto output) {
}); args.at(0).visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) {
return static_cast<float>(x);
});
});
});
arg_1.visit([&](auto output) { arg_1.visit([&](auto output) {
args.at(1).visit( args.at(1).visit([&](auto input) {
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); }); std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) {
}); return static_cast<float>(x);
});
});
});
migemm(result, arg_0, arg_1, 1.0f, 0.0f);
}
else if(output_shape.type() == migraphx::shape::int32_type)
{
arg_0.visit([&](auto output) {
args.at(0).visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) {
return static_cast<int32_t>(x);
});
});
});
migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0}); arg_1.visit([&](auto output) {
args.at(1).visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) {
return static_cast<int32_t>(x);
});
});
});
migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0});
}
return result; return result;
} }
......
...@@ -24,19 +24,23 @@ ...@@ -24,19 +24,23 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1> template <typename DType, typename CType>
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}}; auto dtype = migraphx::shape::get_type<DType>{};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}}; auto ctype = migraphx::shape::get_type<CType>{};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; migraphx::shape m1_shape{dtype, {3, 2, 8, 2}};
migraphx::shape m2_shape{dtype, {3, 2, 7, 8}};
migraphx::shape m3_shape{ctype, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction( auto tl1 = mm->add_instruction(
...@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1> ...@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
auto tl2 = mm->add_instruction( auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2); migraphx::add_apply_alpha_beta(
*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), CType{3}, CType{2});
return p; return p;
} }
}; };
template struct batch_quant_dot_1<int8_t, int32_t>;
template struct batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -28,15 +28,16 @@ ...@@ -28,15 +28,16 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2> template <migraphx::shape::type_t DType, migraphx::shape::type_t CType>
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}}; migraphx::shape m1_shape{DType, {3, 2, 2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}}; migraphx::shape m2_shape{DType, {3, 2, 8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; migraphx::shape m3_shape{CType, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
...@@ -45,3 +46,5 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2> ...@@ -45,3 +46,5 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
return p; return p;
} }
}; };
template struct batch_quant_dot_2<migraphx::shape::int8_type, migraphx::shape::int32_type>;
template struct batch_quant_dot_2<migraphx::shape::fp8e4m3fnuz_type, migraphx::shape::float_type>;
...@@ -27,14 +27,15 @@ ...@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_3 : verify_program<batch_quant_dot_3> template <migraphx::shape::type_t DType>
struct batch_quant_dot_3 : verify_program<batch_quant_dot_3<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 6}}; migraphx::shape m1_shape{DType, {3, 2, 2, 6}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 6, 7}}; migraphx::shape m2_shape{DType, {3, 2, 6, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
...@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3> ...@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
return p; return p;
} }
}; };
template struct batch_quant_dot_3<migraphx::shape::int8_type>;
template struct batch_quant_dot_3<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,15 @@ ...@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_4 : verify_program<batch_quant_dot_4> template <migraphx::shape::type_t DType>
struct batch_quant_dot_4 : verify_program<batch_quant_dot_4<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}}; migraphx::shape m1_shape{DType, {2, 4, 6, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}}; migraphx::shape m2_shape{DType, {7, 2, 6, 3}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
...@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4> ...@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
return p; return p;
} }
}; };
template struct batch_quant_dot_4<migraphx::shape::int8_type>;
template struct batch_quant_dot_4<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,15 @@ ...@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_5 : verify_program<batch_quant_dot_5> template <migraphx::shape::type_t DType>
struct batch_quant_dot_5 : verify_program<batch_quant_dot_5<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}}; migraphx::shape m1_shape{DType, {3, 2, 7, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}}; migraphx::shape m2_shape{DType, {3, 2, 5, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
...@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5> ...@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
return p; return p;
} }
}; };
template struct batch_quant_dot_5<migraphx::shape::int8_type>;
template struct batch_quant_dot_5<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -25,23 +25,31 @@ ...@@ -25,23 +25,31 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1> template <typename DType, typename CType>
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; auto ctype = migraphx::shape::get_type<CType>();
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m1_shape{dtype, {2, 8}};
migraphx::shape m2_shape{dtype, {8, 7}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1); migraphx::add_apply_alpha_beta(
*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{1});
return p; return p;
} }
}; };
template struct quant_dot_3args_1<int8_t, int32_t>;
template struct quant_dot_3args_1<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -28,22 +28,29 @@ ...@@ -28,22 +28,29 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2> template <typename DType, typename CType>
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; auto ctype = migraphx::shape::get_type<CType>();
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m1_shape{dtype, {8, 2}};
migraphx::shape m2_shape{dtype, {8, 7}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3); migraphx::add_apply_alpha_beta(
*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{3});
return p; return p;
} }
}; };
template struct quant_dot_3args_2<int8_t, int32_t>;
template struct quant_dot_3args_2<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -28,22 +28,28 @@ ...@@ -28,22 +28,28 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3> template <typename DType, typename CType>
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; auto ctype = migraphx::shape::get_type<CType>();
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m1_shape{dtype, {2, 8}};
migraphx::shape m2_shape{dtype, {7, 8}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3); migraphx::add_apply_alpha_beta(
*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), CType{2}, CType{3});
return p; return p;
} }
}; };
template struct quant_dot_3args_3<int8_t, int32_t>;
template struct quant_dot_3args_3<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -28,15 +28,18 @@ ...@@ -28,15 +28,18 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_4 : verify_program<quant_dot_3args_4> template <typename DType, typename CType>
struct quant_dot_3args_4 : verify_program<quant_dot_3args_4<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; auto ctype = migraphx::shape::get_type<CType>();
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m1_shape{dtype, {8, 2}};
migraphx::shape m2_shape{dtype, {7, 8}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = auto tl1 =
...@@ -45,7 +48,11 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4> ...@@ -45,7 +48,11 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2); migraphx::add_apply_alpha_beta(
*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), CType{3}, CType{2});
return p; return p;
} }
}; };
template struct quant_dot_3args_4<int8_t, int32_t>;
template struct quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -28,14 +28,17 @@ ...@@ -28,14 +28,17 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_5 : verify_program<quant_dot_3args_5> template <typename DType, typename CType>
struct quant_dot_3args_5 : verify_program<quant_dot_3args_5<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {6, 2}}; auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 6}};
migraphx::shape m1_shape{dtype, {6, 2}};
migraphx::shape m2_shape{dtype, {7, 6}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = auto tl1 =
...@@ -43,7 +46,10 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5> ...@@ -43,7 +46,10 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), 3); migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), CType{3});
return p; return p;
} }
}; };
template struct quant_dot_3args_5<int8_t, int32_t>;
template struct quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment