Commit e4441183 authored by Paul's avatar Paul
Browse files

Add a pass to fuse add and relu

parent 6f96cf7e
......@@ -22,6 +22,7 @@ target_include_directories(migraph_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURR
add_library(migraph_gpu
eliminate_allocation.cpp
eliminate_workspace.cpp
fuse_ops.cpp
hip.cpp
target.cpp
lowering.cpp
......
#include <migraph/gpu/fuse_ops.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/gpu/device/add_relu.hpp>
#include <migraph/instruction.hpp>
namespace migraph {
namespace gpu {
struct hip_add_relu
{
std::string name() const { return "hip::add_relu"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs}.has(3).standard();
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::add_relu(args.at(0), args.at(1), args.at(2));
return args.at(2);
}
};
void fuse_ops::apply(program& p) const
{
assert(ctx != nullptr);
for(auto ins : iterator_for(p))
{
if(ins->op.name() != "gpu::relu")
continue;
auto add_ins = ins->arguments.front();
if(add_ins->op.name() != "gpu::add")
continue;
p.replace_instruction(ins, hip_add_relu{}, add_ins->arguments);
}
}
} // namespace gpu
} // namespace migraph
#ifndef MIGRAPH_GUARD_RTGLIB_FUSE_OPS_HPP
#define MIGRAPH_GUARD_RTGLIB_FUSE_OPS_HPP
#include <migraph/program.hpp>
#include <migraph/gpu/context.hpp>
namespace migraph {
namespace gpu {
struct fuse_ops
{
std::string name() const { return "gpu::fuse_ops"; }
void apply(program& p) const;
};
} // namespace gpu
} // namespace migraph
#endif
......@@ -4,6 +4,7 @@
#include <migraph/gpu/context.hpp>
#include <migraph/gpu/eliminate_workspace.hpp>
#include <migraph/gpu/eliminate_allocation.hpp>
#include <migraph/gpu/fuse_ops.hpp>
#include <migraph/check_context.hpp>
#include <migraph/auto_contiguous.hpp>
#include <migraph/dead_code_elimination.hpp>
......@@ -24,6 +25,8 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
simplify_reshapes{},
dead_code_elimination{},
lowering{ctx},
fuse_ops{},
dead_code_elimination{},
eliminate_workspace{},
eliminate_contiguous{},
dead_code_elimination{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment