"sgl-router/git@developer.sourcefind.cn:change/sglang.git" did not exist on "a4a3d8239365d2fc16e2dc707605373ea3f35ded"
Unverified commit 39bc6161, authored by Shucai Xiao and committed by GitHub

Int8 gemm support (#811)



* add a flag to indicate int8x4 input format

* clang format

* code backup

* clang format

* code backup

* clang format

* code backup

* clang format

* code backup

* clang format

* code backup

* clang format

* remove log info

* remove unnecessary changes

* fix cppcheck error

* add unit tests to have more code coverage

* clang format

* add debug info

* remove log info

* fix cppcheck error

* clang format

* clang format

* add one more unit test for more scenarios

* fix cppcheck error

* clang format

* fix review comments

* clang format

* rename p to m

* fix review comments

* refine unit tests

* clang format

* refine unit tests and fix a bug

* clang format

* fix build error related to ROCm 4.2

* fix a bug related to alpha and beta

* refine two unit tests related to int8_gemm

* fix cppcheck error

* refine unit test to pass on mi100

* add unit test for packing int8 args

* clang format

* change unit tests back

* disable some unit tests for gpu

* clang format

* refine unit tests to run on mi100

* clang format

* refine unit tests

* refine unit tests

* clang format

* change back a unit test
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 0e5682cd
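At a glance, the change threads a single int8_x4_format flag from a rocBLAS layout query (done once in lowering) through the gpu::quant_gemm operator and down to the rocBLAS GEMM call, and the pack_int8_args pass pads the shared k dimension to a multiple of 4 only when that packed layout is actually in use. Below is a condensed sketch of how the per-op flag maps onto a rocBLAS GEMM flag, using only symbols that appear in the diff; the helper name select_gemm_flag is ours, for illustration, and is not part of the patch.

#include <rocblas.h> // header path may differ across ROCm releases

// Mirrors the selection logic gemm_impl gains in this commit: the packed
// int8x4 flag is only passed when the op asked for it and rocBLAS is new
// enough to accept it.
inline rocblas_gemm_flags select_gemm_flag(bool int8_x4_format)
{
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
    return int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
#else
    (void)int8_x4_format; // older rocBLAS: no packing flag to request
    return rocblas_gemm_flags_none;
#endif
}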
@@ -37,8 +37,12 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs)
 }
 
 template <class T>
-void gemm_impl(
-    context& ctx, const shape& output_shape, const std::vector<argument>& args, T alpha, T beta)
+void gemm_impl(context& ctx,
+               const shape& output_shape,
+               const std::vector<argument>& args,
+               T alpha,
+               T beta,
+               bool int8_x4_format)
 {
     bool transa = args[0].get_shape().transposed();
     bool transb = args[1].get_shape().transposed();
@@ -62,6 +66,14 @@ void gemm_impl(
     }
 
     auto compute_type = output_type;
+#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
+    rocblas_gemm_flags flag =
+        int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
+#else
+    (void)int8_x4_format;
+    rocblas_gemm_flags flag = rocblas_gemm_flags_none;
+#endif
+
     auto a_lens = args[0].get_shape().lens();
     auto b_lens = args[1].get_shape().lens();
 
     output_shape.visit_type([&](auto as) {
@@ -72,7 +84,7 @@ void gemm_impl(
         rocblas_int n = out_lens[dim_1];
         rocblas_int k = args[0].get_shape().lens()[dim_1];
         auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
-        if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0)
+        if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
         {
             MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be mutlple of 4!");
         }
@@ -109,11 +121,7 @@ void gemm_impl(
                 compute_type,
                 rocblas_gemm_algo_standard,
                 0,
-#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
-                rocblas_gemm_flags_pack_int8x4);
-#else
-                0);
-#endif
+                flag);
         }
         else
         {
@@ -146,11 +154,7 @@ void gemm_impl(
                 compute_type,
                 rocblas_gemm_algo_standard,
                 0,
-#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
-                rocblas_gemm_flags_pack_int8x4);
-#else
-                0);
-#endif
+                flag);
         }
     });
 }
@@ -159,18 +163,20 @@ void gemm(context& ctx,
           const shape& output_shape,
           const std::vector<argument>& args,
           float alpha,
-          float beta)
+          float beta,
+          bool int8_x4_format)
 {
-    gemm_impl(ctx, output_shape, args, alpha, beta);
+    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format);
 }
 
 void gemm(context& ctx,
           const shape& output_shape,
           const std::vector<argument>& args,
           int32_t alpha,
-          int32_t beta)
+          int32_t beta,
+          bool int8_x4_format)
 {
-    gemm_impl(ctx, output_shape, args, alpha, beta);
+    gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format);
 }
 
 } // namespace gpu
@@ -19,11 +19,13 @@ template <class Op>
 struct rocblas_gemm
 {
     Op op;
+    bool int8_x4_format = true;
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
-        return migraphx::reflect(self.op, f);
+        return pack_join(migraphx::reflect(self.op, f),
+                         pack(f(self.int8_x4_format, "int8_x4_format")));
     }
 
     std::string name() const
@@ -43,22 +45,13 @@ struct rocblas_gemm
         batch_not_transposed(inputs[0].strides());
         batch_not_transposed(inputs[1].strides());
 
-        std::size_t kdim = inputs[0].lens().size() - 1;
-        // k be multiple of 4
-        if(op.name() == "quant_dot" && (inputs[0].lens()[kdim] % 4) != 0)
-        {
-            MIGRAPHX_THROW("GPU_GEMM: size of A {" + to_string_range(inputs[0].lens()) +
-                           "} and B {" + to_string_range(inputs[1].lens()) +
-                           "} must be multiple of 4 for int8 type");
-        }
-
         return op.compute_shape(in_shapes);
     }
 
     argument
     compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
     {
-        gemm(ctx, output_shape, args, op.alpha, op.beta);
+        gemm(ctx, output_shape, args, op.alpha, op.beta, int8_x4_format);
         return args.back();
     }
@@ -13,12 +13,14 @@ void gemm(context& ctx,
          const shape& output_shape,
          const std::vector<argument>& args,
          float alpha,
-         float beta);
+         float beta,
+         bool int8_x4_format);
 
 void gemm(context& ctx,
          const shape& output_shape,
          const std::vector<argument>& args,
          int32_t alpha,
-         int32_t beta);
+         int32_t beta,
+         bool int8_x4_format);
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
@@ -13,7 +13,7 @@ namespace gpu {
 struct pack_int8_args
 {
     std::string name() const { return "gpu::pack_int8_args"; }
-    void apply(module& p) const;
+    void apply(module& m) const;
     shape pack_int8_shape(const shape& s) const;
 };
@@ -55,7 +55,8 @@ struct miopen_apply
     std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
     instruction_ref last{};
     std::unordered_map<instruction_ref, std::string> prog_output_names{};
     bool offload_copy = false;
+    bool int8_x4_format = true;
 
     context& get_context() const
     {
@@ -97,6 +98,13 @@ struct miopen_apply
         assert(mod != nullptr);
         assert(pass != nullptr);
 
+#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
+        auto& ctx = get_context();
+        rocblas_gemm_flags flag;
+        rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
+        int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
+#endif
+
         offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
 
         create_output_names();
@@ -314,7 +322,8 @@ struct miopen_apply
                 }
             }
 
-            return mod->replace_instruction(ins, rocblas_gemm<Op>{Op{op.alpha, beta}}, refs);
+            return mod->replace_instruction(
+                ins, rocblas_gemm<Op>{Op{op.alpha, beta}, int8_x4_format}, refs);
         });
     }
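The lowering hunk above decides the format once per compilation by asking rocBLAS which int8 layout it expects. A minimal standalone sketch of that query outside MIGraphX, assuming rocBLAS >= 2.38 and a plain handle instead of the context's stream handle (illustration only, not part of the patch):

#include <rocblas.h> // header path may differ across ROCm releases
#include <cstdio>

int main()
{
    rocblas_handle handle;
    if(rocblas_create_handle(&handle) != rocblas_status_success)
        return 1;
    rocblas_gemm_flags flag;
    // Same query the lowering pass performs on the context's rocBLAS handle
    rocblas_query_int8_layout_flag(handle, &flag);
    bool int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
    std::printf("packed int8x4 layout expected: %s\n", int8_x4_format ? "yes" : "no");
    rocblas_destroy_handle(handle);
    return 0;
}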
+#include <iterator>
 #include <migraphx/gpu/pack_int8_args.hpp>
 #include <migraphx/gpu/int8_gemm_pack.hpp>
 #include <migraphx/gpu/int8_conv_pack.hpp>
 #include <migraphx/gpu/hip.hpp>
 #include <migraphx/instruction.hpp>
+#include <migraphx/instruction_ref.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/iterator_for.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/permutation.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
-void pack_int8_args::apply(module& p) const
+static instruction_ref pad_ins(module& m, instruction_ref ins, int offset)
 {
-    for(auto ins : iterator_for(p))
+    auto s        = ins->get_shape();
+    auto lens     = s.lens();
+    auto k        = lens[lens.size() + offset];
+    auto pad_k    = (k + 3) / 4 * 4;
+    auto pad_lens = lens;
+    pad_lens[lens.size() + offset] = pad_k;
+    std::vector<int64_t> pad_dims(lens.size() * 2, 0);
+    auto ret_ins = ins;
+    if(pad_k != k)
+    {
+        pad_dims[lens.size() + offset] = pad_k - k;
+        shape ps{s.type(), pad_lens};
+        auto ins_out =
+            m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(ps)}}));
+        auto pad = make_op("pad", {{"pads", pad_dims}});
+        ret_ins  = m.insert_instruction(
+            std::next(ins), make_op("gpu::pad", pad.to_value()), ins, ins_out);
+    }
+
+    return ret_ins;
+}
+
+static std::vector<instruction_ref> pad_inputs(module& m, instruction_ref ins)
+{
+    std::vector<instruction_ref> ret_inputs;
+    auto inputs = ins->inputs();
+    auto in0    = inputs.at(0);
+    auto sa     = in0->get_shape();
+    bool transa = sa.transposed();
+    if(transa)
+    {
+        auto perm = find_permutation(sa);
+        auto val  = in0->get_operator().to_value();
+        if(val.contains("dims"))
+        {
+            int offset = static_cast<int>(perm.back()) - static_cast<int>(perm.size());
+            auto t_in  = in0->inputs().front();
+            auto p_in  = pad_ins(m, t_in, offset);
+            auto dims  = val.at("dims").to_vector<int64_t>();
+            auto r_in  = m.insert_instruction(ins, make_op("transpose", {{"dims", dims}}), p_in);
+            ret_inputs.push_back(r_in);
+        }
+        else
+        {
+            shape cs{in0->get_shape().type(), in0->get_shape().lens()};
+            auto con_out =
+                m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(cs)}}));
+            auto cin0 = m.insert_instruction(ins, make_op("gpu::contiguous"), in0, con_out);
+            ret_inputs.push_back(pad_ins(m, cin0, -1));
+        }
+    }
+    else
+    {
+        ret_inputs.push_back(pad_ins(m, in0, -1));
+    }
+
+    auto in1    = inputs.at(1);
+    auto sb     = in1->get_shape();
+    bool transb = sb.transposed();
+    if(transb)
+    {
+        auto perm = find_permutation(sb);
+        auto val  = in1->get_operator().to_value();
+        if(val.contains("dims"))
+        {
+            int offset = static_cast<int>(perm[perm.size() - 2]) - static_cast<int>(perm.size());
+            auto t_in  = in1->inputs().front();
+            auto p_in  = pad_ins(m, t_in, offset);
+            auto dims  = val.at("dims").to_vector<int64_t>();
+            auto r_in  = m.insert_instruction(ins, make_op("transpose", {{"dims", dims}}), p_in);
+            ret_inputs.push_back(r_in);
+        }
+        else
+        {
+            shape cs{in1->get_shape().type(), in1->get_shape().lens()};
+            auto con_out =
+                m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(cs)}}));
+            auto cin1 = m.insert_instruction(ins, make_op("gpu::contiguous"), in1, con_out);
+            ret_inputs.push_back(pad_ins(m, cin1, -2));
+        }
+    }
+    else
+    {
+        ret_inputs.push_back(pad_ins(m, in1, -2));
+    }
+
+    std::copy(inputs.begin() + 2, inputs.end(), std::back_inserter(ret_inputs));
+
+    return ret_inputs;
+}
+
+void pack_int8_args::apply(module& m) const
+{
+    for(auto ins : iterator_for(m))
     {
         if(ins->name() == "gpu::quant_gemm")
         {
+            auto val = ins->get_operator().to_value();
+            assert(val.contains("int8_x4_format"));
+            if(not val.at("int8_x4_format").to<bool>())
+            {
+                return;
+            }
+
             auto inputs = ins->inputs();
+            auto lens   = inputs.at(0)->get_shape().lens();
+            // gemm need the k to be multiple of 4, so need packing that dimension
+            auto old_inputs = inputs;
+            if((lens.back() % 4) != 0)
+            {
+                inputs = pad_inputs(m, ins);
+            }
+
             bool transa = inputs[0]->get_shape().transposed();
             bool transb = inputs[1]->get_shape().transposed();
             if(!transb)
             {
-                auto packed_b = p.insert_instruction(ins, hip_allocate{inputs[1]->get_shape()});
-                auto output_b =
-                    p.insert_instruction(ins, hip_int8_gemm_pack_a{}, {inputs[1], packed_b});
-                instruction::replace_argument(ins, inputs[1], output_b);
+                auto packed_b = m.insert_instruction(
+                    ins, make_op("hip::allocate", {{"shape", to_value(inputs[1]->get_shape())}}));
+                auto output_b = m.insert_instruction(
+                    ins, make_op("gpu::int8_gemm_pack_a"), {inputs[1], packed_b});
+                inputs[1] = output_b;
             }
 
             if(transa)
             {
-                auto packed_a = p.insert_instruction(ins, hip_allocate{inputs[0]->get_shape()});
-                auto output_a =
-                    p.insert_instruction(ins, hip_int8_gemm_pack_b{}, {inputs[0], packed_a});
-                instruction::replace_argument(ins, inputs[0], output_a);
+                auto packed_a = m.insert_instruction(
+                    ins, make_op("hip::allocate", {{"shape", to_value(inputs[0]->get_shape())}}));
+                auto output_a = m.insert_instruction(
+                    ins, make_op("gpu::int8_gemm_pack_b"), {inputs[0], packed_a});
+                inputs[0] = output_a;
+            }
+
+            if(inputs != old_inputs)
+            {
+                m.replace_instruction(ins, ins->get_operator(), inputs);
             }
         }
         else if(ins->name() == "gpu::quant_convolution")
         {
             auto inputs = ins->inputs();
-            auto packed_x =
-                p.insert_instruction(ins, hip_allocate{pack_int8_shape(inputs[0]->get_shape())});
+            auto packed_x = m.insert_instruction(
+                ins,
+                make_op("hip::allocate",
+                        {{"shape", to_value(pack_int8_shape(inputs[0]->get_shape()))}}));
             auto output_x =
-                p.insert_instruction(ins, miopen_int8_conv_pack{}, {inputs[0], packed_x});
+                m.insert_instruction(ins, make_op("gpu::int8_conv_pack"), {inputs[0], packed_x});
             instruction::replace_argument(ins, inputs[0], output_x);
-            auto packed_w =
-                p.insert_instruction(ins, hip_allocate{pack_int8_shape(inputs[1]->get_shape())});
+            auto packed_w = m.insert_instruction(
+                ins,
+                make_op("hip::allocate",
+                        {{"shape", to_value(pack_int8_shape(inputs[1]->get_shape()))}}));
             auto output_w =
-                p.insert_instruction(ins, miopen_int8_conv_pack{}, {inputs[1], packed_w});
+                m.insert_instruction(ins, make_op("gpu::int8_conv_pack"), {inputs[1], packed_w});
            instruction::replace_argument(ins, inputs[1], output_w);
         }
     }
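The pad_ins helper above rounds the packed dimension up with (k + 3) / 4 * 4 and only inserts a gpu::pad when the size actually changes. A tiny self-contained check of that rounding rule (round_up4 is our illustrative name; the values match the quant_dot_pad and quant_dot_trans_pad tests below, which pad k = 6 to 8 and k = 9 to 12):

#include <cassert>
#include <cstddef>

// Round a dimension up to the next multiple of 4, as rocBLAS's packed
// int8x4 GEMM layout requires for k.
static std::size_t round_up4(std::size_t k) { return (k + 3) / 4 * 4; }

int main()
{
    assert(round_up4(6) == 8);  // k = 6 -> padded to 8
    assert(round_up4(9) == 12); // k = 9 -> padded to 12
    assert(round_up4(8) == 8);  // already a multiple of 4: pad_ins inserts nothing
    return 0;
}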
#include "migraphx/instruction_ref.hpp"
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/adjust_allocation.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_passes(migraphx::module& m)
{
auto ctx = migraphx::gpu::context{};
migraphx::run_passes(m,
{migraphx::auto_contiguous{},
migraphx::gpu::lowering{&ctx, false},
migraphx::dead_code_elimination{},
migraphx::gpu::pack_int8_args{},
migraphx::dead_code_elimination{}});
}
bool get_int8_x4_format()
{
bool int8_x4_format = true;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
auto ctx = migraphx::gpu::context{};
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
return int8_x4_format;
}
TEST_CASE(quant_dot)
{
auto create_module = [] {
migraphx::module m("test");
migraphx::shape m1_shape{migraphx::shape::int8_type, {5, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {5, 7}};
auto l1 = m.add_parameter("a", m1_shape);
auto l2 = m.add_parameter("b", m2_shape);
auto l3 = m.add_parameter("c", m3_shape);
auto r = m.add_instruction(migraphx::make_op("quant_dot"), l1, l2, l3);
m.add_return({r});
return m;
};
auto create_optimized_int8_x4 = [](bool int8_x4) {
migraphx::module m("test");
migraphx::shape m1_shape{migraphx::shape::int8_type, {5, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {5, 7}};
auto l1 = m.add_parameter("a", m1_shape);
auto l2 = m.add_parameter("b", m2_shape);
auto l3 = m.add_parameter("c", m3_shape);
auto output = m.add_parameter("test:#output_0", m3_shape);
auto cout = m.add_instruction(migraphx::make_op("hip::copy"), l3, output);
auto packa = l2;
if(int8_x4)
{
auto alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(m2_shape)}}));
packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), l2, alloc);
}
auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm",
{{"alpha", 1}, {"beta", 1}, {"int8_x4_format", int8_x4}}),
l1,
packa,
cout,
cout);
m.add_return({gemm});
return m;
};
auto m1 = create_module();
run_passes(m1);
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
EXPECT(m1 == m2);
}
TEST_CASE(quant_dot_trans)
{
auto create_module = [] {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 8, 5}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 8}};
auto l1 = m.add_parameter("a", s1);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
auto l2 = m.add_parameter("b", s2);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
auto r = m.add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
m.add_return({r});
return m;
};
auto create_optimized_int8_x4 = [](bool int8_x4) {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 8, 5}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 8}};
migraphx::shape s3{migraphx::shape::int32_type, {3, 2, 5, 7}};
auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2);
auto output = m.add_parameter("test:#output_0", s3);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 8}};
auto alloca = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, alloca);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 8, 7}};
auto allocb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, allocb);
auto packb = contb;
if(int8_x4)
{
auto allocpb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), contb, allocpb);
}
auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm",
{{"alpha", 3}, {"beta", 0}, {"int8_x4_format", int8_x4}}),
conta,
packb,
output);
m.add_return({gemm});
return m;
};
auto m1 = create_module();
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
run_passes(m1);
EXPECT(m1 == m2);
}
TEST_CASE(quant_dot_pad)
{
auto create_module = [] {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {5, 6}};
migraphx::shape s2{migraphx::shape::int8_type, {6, 7}};
migraphx::shape s3{migraphx::shape::int32_type, {5, 7}};
auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2);
auto l3 = m.add_parameter("c", s3);
auto r = m.add_instruction(migraphx::make_op("quant_dot"), l1, l2, l3);
m.add_return({r});
return m;
};
auto create_optimized_int8_x4 = [](bool int8_x4) {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {5, 6}};
migraphx::shape ps1{migraphx::shape::int8_type, {5, 8}};
migraphx::shape s2{migraphx::shape::int8_type, {6, 7}};
migraphx::shape ps2{migraphx::shape::int8_type, {8, 7}};
migraphx::shape s3{migraphx::shape::int32_type, {5, 7}};
auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2);
auto l3 = m.add_parameter("c", s3);
auto output = m.add_parameter("test:#output_0", s3);
auto pl1 = l1;
auto packa = l2;
migraphx::instruction_ref pl2{};
if(int8_x4)
{
auto po1 = m.insert_instruction(
l1, migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}}));
pl1 = m.add_instruction(
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 2, 0, 0}}, {"value", 0}}),
l1,
po1);
auto po2 = m.insert_instruction(
l2, migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
pl2 = m.insert_instruction(
std::next(l2),
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {2, 0, 0, 0}}, {"value", 0}}),
l2,
po2);
}
auto cout = m.add_instruction(migraphx::make_op("hip::copy"), l3, output);
if(int8_x4)
{
auto alloc = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
packa = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pl2, alloc);
}
auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm",
{{"alpha", 1}, {"beta", 1}, {"int8_x4_format", int8_x4}}),
pl1,
packa,
cout,
cout);
m.add_return({gemm});
return m;
};
auto m1 = create_module();
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
run_passes(m1);
EXPECT(m1 == m2);
}
TEST_CASE(quant_dot_trans_pad)
{
auto create_module = [] {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 9, 5}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 9}};
auto l1 = m.add_parameter("a", s1);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
auto l2 = m.add_parameter("b", s2);
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
auto r = m.add_instruction(
migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
m.add_return({r});
return m;
};
auto create_optimized_int8_x4 = [](bool int8_x4) {
migraphx::module m("test");
migraphx::shape s1{migraphx::shape::int8_type, {3, 2, 9, 5}};
migraphx::shape ps1{migraphx::shape::int8_type, {3, 2, 5, 12}};
migraphx::shape s2{migraphx::shape::int8_type, {3, 2, 7, 9}};
migraphx::shape ps2{migraphx::shape::int8_type, {3, 2, 12, 7}};
migraphx::shape s3{migraphx::shape::int32_type, {3, 2, 5, 7}};
auto l1 = m.add_parameter("a", s1);
auto l2 = m.add_parameter("b", s2);
auto output = m.add_parameter("test:#output_0", s3);
auto tl1 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
migraphx::shape ts1{migraphx::shape::int8_type, {3, 2, 5, 9}};
auto ta = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts1)}}));
migraphx::instruction_ref pta{};
if(int8_x4)
{
pta = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps1)}}));
}
auto conta = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl1, ta);
auto pa = conta;
if(int8_x4)
{
pa = m.add_instruction(
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 0, 3, 0, 0, 0, 0}}}),
conta,
pta);
}
auto tl2 = m.add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
migraphx::shape ts2{migraphx::shape::int8_type, {3, 2, 9, 7}};
auto tb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ts2)}}));
migraphx::instruction_ref ptb{};
if(int8_x4)
{
ptb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
}
auto contb = m.add_instruction(migraphx::make_op("gpu::contiguous"), tl2, tb);
auto packb = contb;
if(int8_x4)
{
auto pb = m.add_instruction(
migraphx::make_op("gpu::pad", {{"mode", 0}, {"pads", {0, 0, 3, 0, 0, 0, 0, 0}}}),
contb,
ptb);
auto allocpb = m.add_instruction(
migraphx::make_op("hip::allocate", {{"shape", migraphx::to_value(ps2)}}));
packb = m.add_instruction(migraphx::make_op("gpu::int8_gemm_pack_a"), pb, allocpb);
}
auto gemm = m.add_instruction(
migraphx::make_op("gpu::quant_gemm",
{{"alpha", 3}, {"beta", 0}, {"int8_x4_format", int8_x4}}),
pa,
packb,
output);
m.add_return({gemm});
return m;
};
auto m1 = create_module();
bool flag = get_int8_x4_format();
auto m2 = create_optimized_int8_x4(flag);
run_passes(m1);
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -63,9 +63,9 @@ TEST_CASE(int8_quantization)
     auto create_program = [] {
         migraphx::program p;
         auto* mm = p.get_main_module();
-        migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
+        migraphx::shape sa{migraphx::shape::float_type, {5, 16}};
         migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
-        migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
+        migraphx::shape sc{migraphx::shape::float_type, {5, 8}};
         auto pa = mm->add_parameter("a", sa);
         auto pb = mm->add_parameter("b", sb);
         auto pc = mm->add_parameter("c", sc);
@@ -77,9 +77,9 @@ TEST_CASE(int8_quantization)
     {
         auto p = create_program();
         migraphx::parameter_map m;
-        migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
+        migraphx::shape sa{migraphx::shape::float_type, {5, 16}};
         migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
-        migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
+        migraphx::shape sc{migraphx::shape::float_type, {5, 8}};
         m["a"] = migraphx::generate_argument(sa);
         m["b"] = migraphx::generate_argument(sb);
         m["c"] = migraphx::generate_argument(sc);
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 6}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 6, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), l1, l2);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 0, 1, 2}}}), l1);
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {3, 1, 2, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 3}}), tl1, tl2);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}};
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l1);
auto sl1 = mm->add_instruction(migraphx::make_op("add"), tl1, tl1);
auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), l2);
auto sl2 = mm->add_instruction(migraphx::make_op("add"), tl2, tl2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 1}}), sl1, sl2);
return p;
}
};
@@ -45,5 +45,14 @@ int main(int argc, const char* argv[])
     run_verify rv;
     rv.add_validation_for("gpu", &validate_gpu);
     rv.disable_test_for("cpu", {"test_if_lp", "test_if_param", "test_if_literal"});
+    rv.disable_test_for("gpu",
+                        {"batch_quant_dot_2",
+                         "batch_quant_dot_3",
+                         "batch_quant_dot_5",
+                         "quant_dot_3args_1",
+                         "quant_dot_3args_2",
+                         "quant_dot_3args_3",
+                         "quant_dot_3args_4",
+                         "quant_dot_3args_5"});
     rv.run(argc, argv);
 }
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {6, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 6}};
auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l2);
mm->add_instruction(migraphx::make_op("quant_dot", {{"alpha", 3}, {"beta", 2}}), tl1, tl2);
return p;
}
};