Unverified Commit 650ba45f authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Enable MLIR by default for more cases (#2274)

This will enable MLIR by default for these cases:

Any convolution fusion
Any int8 gemm fusion
All Navi3 standalone convolutions
With a flag(ie MIGRAPHX_ENABLE_MLIR) to enable MLIR for floating-point gemm fusions
Except:

3x3 winnograd convolutions fusions (except on Navi)
K > 2048 on gemm (as CK)
Also there is MIGRAPHX_DISABLE_MLIR to disable MLIR completely.
parent f8bf7bd3
...@@ -109,10 +109,13 @@ def rocmnode(name, body) { ...@@ -109,10 +109,13 @@ def rocmnode(name, body) {
rocmtest clang_debug: rocmnode('mi100+') { cmake_build -> rocmtest clang_debug: rocmnode('mi100+') { cmake_build ->
stage('hipRTC Debug') { stage('hipRTC Debug') {
def sanitizers = "undefined" // Disable MLIR since it doesnt work with all ub sanitizers
def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" withEnv(['MIGRAPHX_DISABLE_MLIR=1']) {
def gpu_targets = getgputargets() def sanitizers = "undefined"
cmake_build(flags: "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}' -DMIGRAPHX_USE_HIPRTC=On -DGPU_TARGETS='${gpu_targets}'", gpu_debug: true) def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
def gpu_targets = getgputargets()
cmake_build(flags: "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}' -DMIGRAPHX_USE_HIPRTC=On -DGPU_TARGETS='${gpu_targets}'", gpu_debug: true)
}
} }
}, clang_release: rocmnode('mi100+') { cmake_build -> }, clang_release: rocmnode('mi100+') { cmake_build ->
stage('Hip Clang Release') { stage('Hip Clang Release') {
...@@ -131,7 +134,7 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build -> ...@@ -131,7 +134,7 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build ->
} }
}, mlir_debug: rocmnode('mi100+') { cmake_build -> }, mlir_debug: rocmnode('mi100+') { cmake_build ->
stage('MLIR Debug') { stage('MLIR Debug') {
withEnv(['MIGRAPHX_ENABLE_MLIR=1']) { withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1']) {
def sanitizers = "undefined" def sanitizers = "undefined"
// Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS. // Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS.
def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}" def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}"
...@@ -142,7 +145,7 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build -> ...@@ -142,7 +145,7 @@ rocmtest clang_debug: rocmnode('mi100+') { cmake_build ->
} }
}, ck_hiprtc: rocmnode('mi100+') { cmake_build -> }, ck_hiprtc: rocmnode('mi100+') { cmake_build ->
stage('CK hipRTC') { stage('CK hipRTC') {
withEnv(['MIGRAPHX_ENABLE_CK=1', 'MIGRAPHX_TUNE_CK=1']) { withEnv(['MIGRAPHX_ENABLE_CK=1', 'MIGRAPHX_TUNE_CK=1', 'MIGRAPHX_DISABLE_MLIR=1']) {
def gpu_targets = getgputargets() def gpu_targets = getgputargets()
cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On -DGPU_TARGETS='${gpu_targets}'") cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On -DGPU_TARGETS='${gpu_targets}'")
} }
......
...@@ -36,24 +36,14 @@ struct module; ...@@ -36,24 +36,14 @@ struct module;
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);
bool mlir_enabled() bool mlir_enabled()
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{}); const bool mlir_disabled = enabled(MIGRAPHX_DISABLE_MLIR{});
if(mlir_enabled) return not mlir_disabled;
{
return true;
}
else
{
std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
"var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
<< std::endl;
return false;
}
#else #else
return false; return false;
#endif #endif
...@@ -157,27 +147,72 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) ...@@ -157,27 +147,72 @@ fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
return {new_gemm_based_op, top_inputs}; return {new_gemm_based_op, top_inputs};
} }
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) enum class mlir_mode
{ {
if(ins->name() != "convolution" and ins->name() != "quant_convolution") all,
return false; fast,
value v = ins->get_operator().to_value(); int8,
auto group = v.at("group").to<int>(); none
if(group != 1) };
return false;
// Avoid MLIR assertion: Index < Length && "Invalid index!" auto is_mlir_dot(mlir_mode mode)
if(ins->get_shape().lens().size() != 4) {
return false; return match::make_basic_pred_matcher([=](instruction_ref ins) {
return true; if(mode == mlir_mode::none)
return false;
if(ins->name() != "dot" and ins->name() != "quant_dot")
return false;
if(mode != mlir_mode::fast)
return true;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
// auto m = a.lens()[a.lens().size() - 2];
// auto n = b.lens().back();
auto k = a.lens().back();
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from MLIR
// To-do: Investigate a more precise strategy
return k <= 2048;
});
}
auto is_mlir_conv(mlir_mode mode)
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(mode == mlir_mode::none)
return false;
if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false;
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
if(group != 1)
return false;
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4)
return false;
if(ins->get_shape().type() == shape::int8_type)
return true;
if(mode == mlir_mode::int8)
return false;
if(mode == mlir_mode::all)
return true;
auto w = ins->inputs().at(1)->get_shape();
if(w.lens().size() != 4)
return true;
if(w.lens()[2] != w.lens()[3])
return true;
return (w.lens()[3] % 3) != 0;
});
} }
struct find_mlir_fused_ops struct find_mlir_fused_ops
{ {
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
auto matcher() const auto matcher() const
{ {
auto dot_or_conv = match::skip(match::name("contiguous"))( auto dot_or_conv = match::skip(match::name("contiguous"))(
match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv()) match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
.bind("gemm_based_op"));
return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x"))); return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
} }
...@@ -309,8 +344,11 @@ struct find_mlir_fused_ops ...@@ -309,8 +344,11 @@ struct find_mlir_fused_ops
} }
}; };
template <auto Matcher>
struct find_mlir_standalone_op struct find_mlir_standalone_op
{ {
mlir_mode mode = mlir_mode::none;
auto matcher() const { return Matcher(mode); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto conv_based_op = r.result; auto conv_based_op = r.result;
...@@ -332,15 +370,8 @@ struct find_mlir_standalone_op ...@@ -332,15 +370,8 @@ struct find_mlir_standalone_op
} }
}; };
struct find_mlir_standalone_convolution_op : find_mlir_standalone_op using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
{ using find_mlir_standalone_dot_op = find_mlir_standalone_op<&is_mlir_dot>;
auto matcher() const { return is_mlir_conv; }
};
struct find_mlir_standalone_dot_op : find_mlir_standalone_op
{
auto matcher() const { return match::any_of(match::name("dot"), match::name("quant_dot")); }
};
/** /**
* @brief Declares a new MIGraphX environment variable which forces to generate * @brief Declares a new MIGraphX environment variable which forces to generate
...@@ -354,44 +385,15 @@ struct find_mlir_standalone_dot_op : find_mlir_standalone_op ...@@ -354,44 +385,15 @@ struct find_mlir_standalone_dot_op : find_mlir_standalone_op
* intended to be primarily used by rocMLIR developers. * intended to be primarily used by rocMLIR developers.
*/ */
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
bool is_self_decide() { return string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "").empty(); }
bool is_requested(std::string_view option) bool is_requested(std::string_view option, bool fallback = false)
{ {
assert(not is_self_decide());
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, ""); auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
if(string_value.empty())
return fallback;
const auto options = split_string(string_value, ','); const auto options = split_string(string_value, ',');
return contains(options, option); return contains(options, option);
} }
bool is_enabled(std::string_view op_name, context* ctx)
{
if(is_self_decide())
{
if(op_name == "fused")
{
return true;
}
else if(op_name == "convolution" or op_name == "quant_convolution")
{
if(ctx == nullptr)
{
return false;
}
else
{
const auto& device = ctx->get_current_device();
const std::string navi_family{"gfx110"};
return starts_with(device.get_gfx_name(), navi_family);
}
}
else
{
return false;
}
}
return is_requested(op_name);
}
} // namespace } // namespace
#endif // MIGRAPHX_MLIR #endif // MIGRAPHX_MLIR
...@@ -399,20 +401,28 @@ bool is_enabled(std::string_view op_name, context* ctx) ...@@ -399,20 +401,28 @@ bool is_enabled(std::string_view op_name, context* ctx)
void fuse_mlir::apply(module_pass_manager& mpm) const void fuse_mlir::apply(module_pass_manager& mpm) const
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
if(is_enabled("fused", this->ctx)) const auto& device_name = ctx == nullptr ? "" : ctx->get_current_device().get_gfx_name();
{ const bool is_navi = starts_with(device_name, "gfx110");
match::find_matches(mpm, find_mlir_fused_ops{});
}
if(is_enabled("convolution", this->ctx)) auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) {
{ if(is_requested(option))
match::find_matches(mpm, find_mlir_standalone_convolution_op{}); return mlir_mode::all;
} if(is_navi)
return mlir_mode::all;
return std::max(m1, m2);
};
if(is_enabled("dot", this->ctx)) mlir_mode mode =
{ (enabled(MIGRAPHX_ENABLE_EXTRA_MLIR{}) or enable_extra) ? mlir_mode::fast : mlir_mode::none;
match::find_matches(mpm, find_mlir_standalone_dot_op{});
} match::find_matches(mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
.dot_mode = get_mode("fused", mode)});
match::find_matches(
mpm,
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
#else #else
(void)mpm; (void)mpm;
#endif #endif
......
...@@ -38,6 +38,7 @@ MIGRAPHX_GPU_EXPORT bool mlir_enabled(); ...@@ -38,6 +38,7 @@ MIGRAPHX_GPU_EXPORT bool mlir_enabled();
struct MIGRAPHX_GPU_EXPORT fuse_mlir struct MIGRAPHX_GPU_EXPORT fuse_mlir
{ {
context* ctx = nullptr; context* ctx = nullptr;
bool enable_extra = false;
std::string name() const { return "gpu::fuse_mlir"; } std::string name() const { return "gpu::fuse_mlir"; }
void apply(module_pass_manager& mpm) const; void apply(module_pass_manager& mpm) const;
}; };
......
...@@ -34,7 +34,8 @@ ...@@ -34,7 +34,8 @@
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
{ {
migraphx::run_passes(p, {migraphx::gpu::fuse_mlir{}, migraphx::dead_code_elimination{}}); migraphx::run_passes(
p, {migraphx::gpu::fuse_mlir{.enable_extra = true}, migraphx::dead_code_elimination{}});
} }
template <class F> template <class F>
...@@ -151,7 +152,6 @@ TEST_CASE(int_quant_dot_tanh_fails) ...@@ -151,7 +152,6 @@ TEST_CASE(int_quant_dot_tanh_fails)
int main(int argc, const char* argv[]) int main(int argc, const char* argv[])
{ {
if(migraphx::gpu::mlir_enabled()) test::run(argc, argv);
test::run(argc, argv);
return 0; return 0;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment