Commit cb9bfaf4 authored by Paul
Browse files

Add mul_add_relu fusion

parent 3135fc93
...@@ -6,6 +6,16 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -6,6 +6,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
// Launches a fused multiply-add followed by ReLU on the given HIP stream.
//
// The functor receives three values (x, a, b) and yields max(0, a * x + b),
// with the zero converted to the type of the multiply-add expression.
// NOTE(review): presumably `nary` applies the functor element-wise, feeding
// it elements of arg1, arg2 and arg3 (in that order) and storing the value
// into `result` -- confirm against the definition of `nary`.
void mul_add_relu(hipStream_t stream,
                  const argument& result,
                  const argument& arg1,
                  const argument& arg2,
                  const argument& arg3)
{
    auto fused_relu = [](auto x, auto a, auto b) {
        // Evaluate the multiply-add once and clamp it at zero.
        auto scaled = a * x + b;
        return std::max<decltype(scaled)>(0, scaled);
    };
    nary(stream, result, arg1, arg2, arg3)(fused_relu);
}
void add_relu(hipStream_t stream, void add_relu(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg1, const argument& arg1,
......
...@@ -220,6 +220,25 @@ struct hip_mul_add ...@@ -220,6 +220,25 @@ struct hip_mul_add
} }
}; };
// Operator that dispatches to device::mul_add_relu.
//
// Arguments 0..2 are forwarded to the device routine as its three operands;
// argument 3 is passed as the result buffer and is also what compute()
// returns.
struct hip_mul_add_relu
{
    std::string name() const { return "hip::mul_add_relu"; }

    // The reported shape is that of the first input.
    // NOTE(review): `has(4)` presumably enforces exactly four inputs -- confirm
    // against check_shapes.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(4);
        return inputs.front();
    }

    // Runs the device routine on the context's stream and hands back the
    // argument that was used as the result buffer.
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
    {
        const argument& output = args.at(3);
        device::mul_add_relu(ctx.get_stream().get(), output, args.at(0), args.at(1), args.at(2));
        return output;
    }

    // Index of the last shape, i.e. the position of the output buffer.
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};
void move_broadcasted_back(std::vector<instruction_ref>& args) void move_broadcasted_back(std::vector<instruction_ref>& args)
{ {
// Ensure the last arguments is the broadcasted one // Ensure the last arguments is the broadcasted one
...@@ -325,6 +344,25 @@ struct find_mul_add ...@@ -325,6 +344,25 @@ struct find_mul_add
} }
}; };
// Rewrites a "gpu::relu" whose first argument is a "hip::mul_add" into a
// single hip_mul_add_relu instruction.
struct find_mul_add_relu
{
    // Matches gpu::relu with a hip::mul_add as argument 0; the mul_add
    // instruction is bound under the name "mul_add".
    auto matcher() const
    {
        return match::name("gpu::relu")(
            match::arg(0)(
                match::name("hip::mul_add").bind("mul_add")));
    }

    void apply(program& p, match::matcher_result r) const
    {
        auto relu_ins = r.result;
        // Start from a copy of the mul_add inputs...
        auto fused_args = r.instructions["mul_add"]->inputs();
        // ...but use the allocation from the relu operator as the last input.
        fused_args.back() = relu_ins->inputs().back();
        p.replace_instruction(relu_ins, hip_mul_add_relu{}, fused_args);
    }
};
struct miopen_conv_bias struct miopen_conv_bias
{ {
op::convolution op; op::convolution op;
...@@ -480,8 +518,9 @@ void fuse_ops::apply(program& p) const ...@@ -480,8 +518,9 @@ void fuse_ops::apply(program& p) const
match::find_matches(p, match::find_matches(p,
find_conv_bias_relu{ctx}, find_conv_bias_relu{ctx},
find_conv_bias{ctx}, find_conv_bias{ctx},
find_add_relu{}, find_mul_add{},
find_mul_add{} find_mul_add_relu{},
find_add_relu{}
); );
// clang-format on // clang-format on
} }
......
...@@ -11,6 +11,12 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,6 +11,12 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
// Fused multiply-add + ReLU device routine (declaration only).
// stream: HIP stream handed to the routine.
// result: argument passed as the output.
// arg1, arg2, arg3: the three input arguments.
// NOTE(review): presumably computes max(0, arg2 * arg1 + arg3) element-wise,
// per the definition added in the same commit -- confirm against the
// implementation file.
void mul_add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
void add_relu(hipStream_t stream, void add_relu(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg1, const argument& arg1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment