Commit 7309f43a authored by Paul's avatar Paul
Browse files

Enable fast mode by default

parent 0039b11a
...@@ -164,15 +164,21 @@ auto is_mlir_dot(mlir_mode mode) ...@@ -164,15 +164,21 @@ auto is_mlir_dot(mlir_mode mode)
return false; return false;
if(mode != mlir_mode::fast) if(mode != mlir_mode::fast)
return true; return true;
auto a = ins->inputs().front()->get_shape(); float a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape(); float b = ins->inputs().back()->get_shape();
// auto m = a.lens()[a.lens().size() - 2]; float m = a.lens()[a.lens().size() - 2];
// auto n = b.lens().back(); float n = b.lens().back();
auto k = a.lens().back(); float k = a.lens().back();
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy if (k > 1024)
// to avoid poor-performing GEMM kernels from MLIR return false;
// To-do: Investigate a more precise strategy auto ratio = m*n/k;
return k <= 2048; if (ratio < 16384)
return false;
return true;
// // Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// // to avoid poor-performing GEMM kernels from MLIR
// // To-do: Investigate a more precise strategy
// return k <= 2048;
}); });
} }
...@@ -418,7 +424,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const ...@@ -418,7 +424,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
match::find_matches(mpm, match::find_matches(mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast), find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
.dot_mode = get_mode("fused", mode)}); .dot_mode = get_mode("fused", mlir_mode::fast)});
match::find_matches( match::find_matches(
mpm, mpm,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment