"...compression/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "5f571327902c84c208482f66c2b293ad1013ee3d"
Commit 6e7c8be8 authored by Paul's avatar Paul
Browse files

Add fuse_mlir pass

parent 7f65a88e
...@@ -140,6 +140,7 @@ add_library(migraphx_gpu ...@@ -140,6 +140,7 @@ add_library(migraphx_gpu
device_name.cpp device_name.cpp
eliminate_workspace.cpp eliminate_workspace.cpp
elu.cpp elu.cpp
fuse_mlir.cpp
fuse_ops.cpp fuse_ops.cpp
gather.cpp gather.cpp
gemm_impl.cpp gemm_impl.cpp
......
#ifndef MIGRAPHX_GUARD_GPU_FUSE_MLIR_HPP
#define MIGRAPHX_GUARD_GPU_FUSE_MLIR_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
struct fuse_mlir
{
context* ctx = nullptr;
std::string name() const { return "gpu::fuse_mlir"; }
void apply(module& m) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_FUSE_MLIR_HPP
...@@ -509,8 +509,11 @@ std::string dump_mlir(const module& m) ...@@ -509,8 +509,11 @@ std::string dump_mlir(const module& m)
code_object_op compile_mlir(const module& m) code_object_op compile_mlir(const module& m)
{ {
std::cout << m << std::endl;
mlir_program mp; mlir_program mp;
mp.parse(m); mp.parse(m);
auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
return mp.compile(); return mp.compile();
} }
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/eliminate_workspace.hpp> #include <migraphx/gpu/eliminate_workspace.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp> #include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/lowering.hpp> #include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/mlir_conv.hpp> #include <migraphx/gpu/mlir_conv.hpp>
...@@ -102,7 +103,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -102,7 +103,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
fuse_pointwise{}, fuse_pointwise{},
dead_code_elimination{}, dead_code_elimination{},
mlir_conv{&ctx}, fuse_mlir{&ctx},
dead_code_elimination{},
lowering{&ctx, options.offload_copy}, lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"}, eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{}, dead_code_elimination{},
...@@ -112,8 +114,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -112,8 +114,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
adjust_allocation{gpu_allocation_model{}}, adjust_allocation{gpu_allocation_model{}},
dead_code_elimination{}, dead_code_elimination{},
fuse_ops{&ctx, options.fast_math}, // fuse_ops{&ctx, options.fast_math},
dead_code_elimination{}, // dead_code_elimination{},
compile_ops{&ctx}, compile_ops{&ctx},
dead_code_elimination{}, dead_code_elimination{},
write_literals{&ctx}, write_literals{&ctx},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment