Commit e0724664 authored by Paul's avatar Paul
Browse files

Enable pass

parent 0ee0d6e7
...@@ -93,6 +93,7 @@ add_library(migraphx_gpu ...@@ -93,6 +93,7 @@ add_library(migraphx_gpu
compile_miopen.cpp compile_miopen.cpp
compiler.cpp compiler.cpp
device_name.cpp device_name.cpp
fuse_ck.cpp
fuse_mlir.cpp fuse_mlir.cpp
fuse_ops.cpp fuse_ops.cpp
gather.cpp gather.cpp
......
...@@ -6,9 +6,6 @@ ...@@ -6,9 +6,6 @@
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK);
struct module; struct module;
namespace gpu { namespace gpu {
...@@ -122,11 +119,8 @@ struct find_ck_gemm ...@@ -122,11 +119,8 @@ struct find_ck_gemm
void fuse_ck::apply(module_pass_manager& mpm) const void fuse_ck::apply(module_pass_manager& mpm) const
{ {
if(enabled(MIGRAPHX_ENABLE_CK{}))
{
match::find_matches(mpm, find_ck_gemm_pointwise{}); match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{}); match::find_matches(mpm, find_ck_gemm{});
}
} }
} // namespace gpu } // namespace gpu
......
...@@ -57,6 +57,7 @@ ...@@ -57,6 +57,7 @@
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/fuse_mlir.hpp> #include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp> #include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp> #include <migraphx/gpu/prefuse_ops.hpp>
...@@ -74,6 +75,7 @@ namespace gpu { ...@@ -74,6 +75,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
struct id_pass struct id_pass
{ {
std::string name() const { return "id"; } std::string name() const { return "id"; }
...@@ -133,6 +135,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -133,6 +135,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}), enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
dead_code_elimination{}, dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),
dead_code_elimination{},
enable_pass(mlir_enabled(), fuse_mlir{&ctx}), enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
dead_code_elimination{}, dead_code_elimination{},
lowering{&ctx, options.offload_copy}, lowering{&ctx, options.offload_copy},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment