Commit d5fa82db authored by Umang Yadav's avatar Umang Yadav
Browse files

add quant_dot support for fp8

parent 7e80f627
......@@ -44,9 +44,10 @@ struct quant_dot
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
if(t != shape::int8_type)
std::set<migraphx::shape::type_t> suppported_types = {shape::int8_type, shape::fp8e4m3fnuz_type};
if(not contains(suppported_types, t))
{
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t and fp8e4m3fnuz_type");
}
if(not std::all_of(
......@@ -73,6 +74,10 @@ struct quant_dot
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
if(t == shape::fp8e4m3fnuz_type)
{
return {shape::float_type, out_lens};
} // else int8 gemm
return {shape::int32_type, out_lens};
}
};
......
......@@ -183,6 +183,11 @@ struct find_nested_convert
auto x = ins->inputs().front();
auto input = x->inputs().front();
while(input->name() == "convert")
{
input = input->inputs().front();
}
if(ins->get_shape() != input->get_shape())
return;
......
......@@ -112,7 +112,7 @@ struct rocblas_gemm
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{
if(this->name() == "gpu::gemm")
if(this->name() == "gpu::gemm" or output_shape.type() == migraphx::shape::float_type)
{
gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
}
......
......@@ -24,6 +24,7 @@
#include <migraphx/ref/lowering.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp>
......@@ -307,19 +308,46 @@ struct ref_quant_gemm
{
argument result{output_shape};
// first, convert the args[0] and args[1] from int8_t to int32_t
argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}};
argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}};
arg_0.visit([&](auto output) {
args.at(0).visit(
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
});
argument arg_0{{output_shape.type(), {args.at(0).get_shape().lens()}}};
argument arg_1{{output_shape.type(), {args.at(1).get_shape().lens()}}};
if(output_shape.type() == migraphx::shape::float_type)
{
arg_0.visit([&](auto output) {
args.at(0).visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) {
return static_cast<float>(x);
});
});
});
arg_1.visit([&](auto output) {
args.at(1).visit(
[&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
});
arg_1.visit([&](auto output) {
args.at(1).visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) {
return static_cast<float>(x);
});
});
});
migemm(result, arg_0, arg_1, 1.0f, 0.0f);
}
else if(output_shape.type() == migraphx::shape::int32_type)
{
arg_0.visit([&](auto output) {
args.at(0).visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) {
return static_cast<int32_t>(x);
});
});
});
migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0});
arg_1.visit([&](auto output) {
args.at(1).visit([&](auto input) {
std::transform(input.begin(), input.end(), output.begin(), [&](const auto x) {
return static_cast<int32_t>(x);
});
});
});
migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0});
}
return result;
}
......
......@@ -24,19 +24,23 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
template <typename DType, typename CType>
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1<DType, CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto dtype = migraphx::shape::get_type<DType>{};
auto ctype = migraphx::shape::get_type<CType>{};
migraphx::shape m1_shape{dtype, {3, 2, 8, 2}};
migraphx::shape m2_shape{dtype, {3, 2, 7, 8}};
migraphx::shape m3_shape{ctype, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction(
......@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2);
migraphx::add_apply_alpha_beta(
*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), CType{3}, CType{2});
return p;
}
};
template struct batch_quant_dot_1<int8_t, int32_t>;
template struct batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>;
......@@ -28,15 +28,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
template <migraphx::shape::type_t DType, migraphx::shape::type_t CType>
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2<DType, CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
migraphx::shape m1_shape{DType, {3, 2, 2, 8}};
migraphx::shape m2_shape{DType, {3, 2, 8, 7}};
migraphx::shape m3_shape{CType, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
......@@ -45,3 +46,5 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
return p;
}
};
template struct batch_quant_dot_2<migraphx::shape::int8_type, migraphx::shape::int32_type>;
template struct batch_quant_dot_2<migraphx::shape::fp8e4m3fnuz_type, migraphx::shape::float_type>;
......@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
template <migraphx::shape::type_t DType>
struct batch_quant_dot_3 : verify_program<batch_quant_dot_3<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 6}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 6, 7}};
migraphx::shape m1_shape{DType, {3, 2, 2, 6}};
migraphx::shape m2_shape{DType, {3, 2, 6, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
......@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
return p;
}
};
template struct batch_quant_dot_3<migraphx::shape::int8_type>;
template struct batch_quant_dot_3<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
template <migraphx::shape::type_t DType>
struct batch_quant_dot_4 : verify_program<batch_quant_dot_4<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}};
migraphx::shape m1_shape{DType, {2, 4, 6, 3}};
migraphx::shape m2_shape{DType, {7, 2, 6, 3}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
......@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
return p;
}
};
template struct batch_quant_dot_4<migraphx::shape::int8_type>;
template struct batch_quant_dot_4<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
template <migraphx::shape::type_t DType>
struct batch_quant_dot_5 : verify_program<batch_quant_dot_5<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}};
migraphx::shape m1_shape{DType, {3, 2, 7, 2}};
migraphx::shape m2_shape{DType, {3, 2, 5, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
......@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
return p;
}
};
template struct batch_quant_dot_5<migraphx::shape::int8_type>;
template struct batch_quant_dot_5<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -25,23 +25,31 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1>
template <typename DType, typename CType>
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1<DType, CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto ctype = migraphx::shape::get_type<CType>();
auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m1_shape{dtype, {2, 8}};
migraphx::shape m2_shape{dtype, {8, 7}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1);
migraphx::add_apply_alpha_beta(
*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{1});
return p;
}
};
template struct quant_dot_3args_1<int8_t, int32_t>;
template struct quant_dot_3args_1<migraphx::fp8::fp8e4m3fnuz, float>;
......@@ -28,22 +28,29 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2>
template <typename DType, typename CType>
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2<DType, CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto ctype = migraphx::shape::get_type<CType>();
auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m1_shape{dtype, {8, 2}};
migraphx::shape m2_shape{dtype, {8, 7}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3);
migraphx::add_apply_alpha_beta(
*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{3});
return p;
}
};
template struct quant_dot_3args_2<int8_t, int32_t>;
template struct quant_dot_3args_2<migraphx::fp8::fp8e4m3fnuz, float>;
......@@ -28,22 +28,28 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3>
template <typename DType, typename CType>
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3<DType, CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto ctype = migraphx::shape::get_type<CType>();
auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m1_shape{dtype, {2, 8}};
migraphx::shape m2_shape{dtype, {7, 8}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3);
migraphx::add_apply_alpha_beta(
*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), CType{2}, CType{3});
return p;
}
};
template struct quant_dot_3args_3<int8_t, int32_t>;
template struct quant_dot_3args_3<migraphx::fp8::fp8e4m3fnuz, float>;
......@@ -28,15 +28,18 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
template <typename DType, typename CType>
struct quant_dot_3args_4 : verify_program<quant_dot_3args_4<DType, CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}};
auto ctype = migraphx::shape::get_type<CType>();
auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m1_shape{dtype, {8, 2}};
migraphx::shape m2_shape{dtype, {7, 8}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 =
......@@ -45,7 +48,11 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2);
migraphx::add_apply_alpha_beta(
*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), CType{3}, CType{2});
return p;
}
};
template struct quant_dot_3args_4<int8_t, int32_t>;
template struct quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>;
......@@ -28,14 +28,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
template <typename DType, typename CType>
struct quant_dot_3args_5 : verify_program<quant_dot_3args_5<DType, CType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {6, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 6}};
auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m1_shape{dtype, {6, 2}};
migraphx::shape m2_shape{dtype, {7, 6}};
auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 =
......@@ -43,7 +46,10 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), 3);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), CType{3});
return p;
}
};
template struct quant_dot_3args_5<int8_t, int32_t>;
template struct quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment