"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "0af1d239cbc6da81c9e6dc64bfb56b4a917e3a1f"
Unverified Commit f47e0b5b authored by turneram's avatar turneram Committed by GitHub
Browse files

CK GEMM Int8 Bug Fixes (#2229)

Adds workarounds to avoid passing capture ops and scalar literals from quantization as arguments to ck_gemm.
parent b8b4630b
...@@ -33,6 +33,8 @@ ...@@ -33,6 +33,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK_WORKAROUNDS);
void apply_quantizelinear(module& m, instruction_ref ins) void apply_quantizelinear(module& m, instruction_ref ins)
{ {
assert(ins->name() == "quantizelinear"); assert(ins->name() == "quantizelinear");
...@@ -63,8 +65,21 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -63,8 +65,21 @@ void apply_quantizelinear(module& m, instruction_ref ins)
min_quant = qt.min(); min_quant = qt.min();
}); });
auto s = add_zero_point->get_shape(); auto s = add_zero_point->get_shape();
auto min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}}); instruction_ref min_arg;
auto max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}}); instruction_ref max_arg;
if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
{
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
min_arg = m.add_literal(literal(s, min_data));
max_arg = m.add_literal(literal(s, max_data));
}
else
{
min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
}
auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg}); auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg});
m.replace_instruction( m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate); ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/gpu/device_name.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -92,6 +93,8 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -92,6 +93,8 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
auto m = a.lens()[a.lens().size() - 2]; auto m = a.lens()[a.lens().size() - 2];
auto n = b.lens().back(); auto n = b.lens().back();
auto k = a.lens().back(); auto k = a.lens().back();
auto batch_size = std::accumulate(
a.lens().rbegin() + 2, a.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
// Integer gemms must be divisible by 4 in ck // Integer gemms must be divisible by 4 in ck
if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type())) if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type()))
{ {
...@@ -102,9 +105,17 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -102,9 +105,17 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
if(k % 4 != 0) if(k % 4 != 0)
return false; return false;
} }
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy auto device_name = trim(split_string(get_device_name(), ':').front());
// to avoid poor-performing GEMM kernels from CK if(device_name == "gfx940")
// To-do: Investigate a more precise strategy {
if(ins->get_shape().type() == shape::half_type)
{
if(batch_size >= 64)
return m < 2048 or k <= 64 or n <= 384 or n >= 2048;
return true;
}
return true;
}
return k <= 2048; return k <= 2048;
} }
...@@ -140,6 +151,10 @@ struct find_ck_gemm_pointwise ...@@ -140,6 +151,10 @@ struct find_ck_gemm_pointwise
return not input->inputs().empty() and input->inputs().front()->name() == "capture"; return not input->inputs().empty() and input->inputs().front()->name() == "capture";
})) }))
return; return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto input) {
return not input->inputs().empty() and input->inputs().front()->name() == "capture";
}))
return;
assert(gemm_it != inputs.end()); assert(gemm_it != inputs.end());
if(gemm_idx != 0) if(gemm_idx != 0)
{ {
......
...@@ -42,11 +42,14 @@ ...@@ -42,11 +42,14 @@
#include <migraphx/op/lrn.hpp> #include <migraphx/op/lrn.hpp>
#include <migraphx/op/reshape.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/unknown.hpp> #include <migraphx/op/unknown.hpp>
#include <migraphx/env.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include "test.hpp" #include "test.hpp"
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK_WORKAROUNDS);
migraphx::program optimize_onnx(const std::string& name, bool run_passes = false) migraphx::program optimize_onnx(const std::string& name, bool run_passes = false)
{ {
migraphx::onnx_options options; migraphx::onnx_options options;
...@@ -5540,6 +5543,31 @@ TEST_CASE(qlinearmatmul_2D_test) ...@@ -5540,6 +5543,31 @@ TEST_CASE(qlinearmatmul_2D_test)
EXPECT(p.sort() == prog.sort()); EXPECT(p.sort() == prog.sort());
} }
migraphx::instruction_ref insert_quantizelinear_clip(migraphx::module& m,
const migraphx::instruction_ref ins,
const migraphx::instruction_ref round,
const migraphx::shape s,
const int64_t min_quant,
const int64_t max_quant)
{
migraphx::instruction_ref min_arg;
migraphx::instruction_ref max_arg;
if(migraphx::enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
{
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
min_arg = m.add_literal(migraphx::literal(s, min_data));
max_arg = m.add_literal(migraphx::literal(s, max_data));
}
else
{
min_arg = m.add_literal(migraphx::literal{migraphx::shape{s.type()}, {min_quant}});
max_arg = m.add_literal(migraphx::literal{migraphx::shape{s.type()}, {max_quant}});
}
return migraphx::insert_common_op(m, ins, migraphx::make_op("clip"), {round, min_arg, max_arg});
}
TEST_CASE(quantizelinear_test) TEST_CASE(quantizelinear_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -5551,13 +5579,7 @@ TEST_CASE(quantizelinear_test) ...@@ -5551,13 +5579,7 @@ TEST_CASE(quantizelinear_test)
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast); auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div); auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto s = round->get_shape(); auto s = round->get_shape();
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {0}}); auto clip = insert_quantizelinear_clip(*mm, div, round, s, 0, 255);
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {255}});
auto min_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_mbcast, max_mbcast);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
...@@ -5582,13 +5604,7 @@ TEST_CASE(quantizelinear_int32_test) ...@@ -5582,13 +5604,7 @@ TEST_CASE(quantizelinear_int32_test)
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast); auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div); auto round = mm->add_instruction(migraphx::make_op("round"), div);
auto s = round->get_shape(); auto s = round->get_shape();
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {0}}); auto clip = insert_quantizelinear_clip(*mm, div, round, s, 0, 255);
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {255}});
auto min_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), round, min_mbcast, max_mbcast);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
...@@ -5617,13 +5633,7 @@ TEST_CASE(quantizelinear_zero_point_test) ...@@ -5617,13 +5633,7 @@ TEST_CASE(quantizelinear_zero_point_test)
l2_mbcast); l2_mbcast);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast); auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_mbcast);
auto s = round->get_shape(); auto s = round->get_shape();
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {-128}}); auto clip = insert_quantizelinear_clip(*mm, div, add, s, -128, 127);
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {127}});
auto min_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_mbcast, max_mbcast);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
...@@ -5656,13 +5666,7 @@ migraphx::program make_quantizelinear_axis_prog() ...@@ -5656,13 +5666,7 @@ migraphx::program make_quantizelinear_axis_prog()
l2_bcast); l2_bcast);
auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_bcast); auto add = mm->add_instruction(migraphx::make_op("add"), round, l2_bcast);
auto s = round->get_shape(); auto s = round->get_shape();
auto min_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {-128}}); auto clip = insert_quantizelinear_clip(*mm, div, add, s, -128, 127);
auto max_arg = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {127}});
auto min_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), min_arg);
auto max_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), max_arg);
auto clip = mm->add_instruction(migraphx::make_op("clip"), add, min_mbcast, max_mbcast);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::int8_type)}}),
......
...@@ -31,10 +31,13 @@ ...@@ -31,10 +31,13 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <test.hpp> #include <test.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/env.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK_WORKAROUNDS);
bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; } bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; }
bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; } bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; }
bool is_clip_scalar(migraphx::instruction& ins) bool is_clip_scalar(migraphx::instruction& ins)
...@@ -82,6 +85,10 @@ TEST_CASE(quantizelinear) ...@@ -82,6 +85,10 @@ TEST_CASE(quantizelinear)
EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear)); EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear));
EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear)); EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear));
// ensure clip literals created in quantized program are scalar // ensure clip literals created in quantized program are scalar
// unless CK workarounds are enabled
if(migraphx::enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
EXPECT(none_of(*p2.get_main_module(), &is_clip_scalar));
else
EXPECT(any_of(*p2.get_main_module(), &is_clip_scalar)); EXPECT(any_of(*p2.get_main_module(), &is_clip_scalar));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment