Commit 7309f43a authored by Paul's avatar Paul
Browse files

Enable fast mode by default

parent 0039b11a
......@@ -164,15 +164,21 @@ auto is_mlir_dot(mlir_mode mode)
return false;
if(mode != mlir_mode::fast)
return true;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
// auto m = a.lens()[a.lens().size() - 2];
// auto n = b.lens().back();
auto k = a.lens().back();
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from MLIR
// To-do: Investigate a more precise strategy
return k <= 2048;
float a = ins->inputs().front()->get_shape();
float b = ins->inputs().back()->get_shape();
float m = a.lens()[a.lens().size() - 2];
float n = b.lens().back();
float k = a.lens().back();
if (k > 1024)
return false;
auto ratio = m*n/k;
if (ratio < 16384)
return false;
return true;
// // Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// // to avoid poor-performing GEMM kernels from MLIR
// // To-do: Investigate a more precise strategy
// return k <= 2048;
});
}
......@@ -418,7 +424,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
match::find_matches(mpm,
find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
.dot_mode = get_mode("fused", mode)});
.dot_mode = get_mode("fused", mlir_mode::fast)});
match::find_matches(
mpm,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment