Commit ef5e7ce0 authored by kahmed10, committed by mvermeulen

Add fusions for sigmoid and tanh (#354)

* add tests, fix bug in ternary op

* formatting

* uncomment fusion
parent 01615379
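
For context, a rough sketch of the rewrite this commit performs (illustrative only; the authoritative pattern is the find_add_unary matcher in fuse_ops.cpp below). When the result of a gpu::add or hip::triadd is used exactly once by gpu::sigmoid, gpu::tanh, or gpu::relu, the pair is collapsed into a single fused kernel that reuses the unary op's output allocation:

    // before fusion (conceptual IR, names as used in this commit)
    sum = gpu::add(x, y, alloc0)
    out = gpu::sigmoid(sum, alloc1)

    // after fusion
    out = hip::add_sigmoid(x, y, alloc1)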
@@ -30,7 +30,7 @@ add_library(migraphx_device
device/acos.cpp
device/atan.cpp
device/relu.cpp
device/add_relu.cpp
device/add_unary.cpp
device/contiguous.cpp
device/logsoftmax.cpp
device/softmax.cpp
......
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/add_unary.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
@@ -25,6 +25,23 @@ void add_relu(hipStream_t stream,
[](auto x, auto y) { return std::max<decltype(x + y)>(0, x + y); });
}
void add_sigmoid(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2)
{
nary(stream, result, arg1, arg2)(
[](auto x, auto y) { return 1.f / (1.f + ::exp(to_hip_type(-(x + y)))); });
}
void add_tanh(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) { return ::tanh(to_hip_type(x + y)); });
}
void add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
@@ -35,6 +52,26 @@ void add_relu(hipStream_t stream,
[](auto x, auto y, auto z) { return std::max<decltype(x + y + z)>(0, x + y + z); });
}
void add_sigmoid(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
nary(stream, result, arg1, arg2, arg3)(
[](auto x, auto y, auto z) { return 1.f / (1.f + ::exp(to_hip_type(-(x + y + z)))); });
}
void add_tanh(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
nary(stream, result, arg1, arg2, arg3)(
[](auto x, auto y, auto z) { return ::tanh(to_hip_type(x + y + z)); });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/add_unary.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
......
@@ -2,8 +2,9 @@
#include <migraphx/matcher.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/mul_add.hpp>
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/add_unary.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/array.hpp>
@@ -161,42 +162,28 @@ struct hip_triadd
}
};
struct hip_triadd_relu
struct hip_triadd_relu : ternary_device<hip_triadd_relu, &device::add_relu>
{
std::string name() const { return "hip::triadd_relu"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::add_relu(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
struct hip_add_relu
struct hip_triadd_sigmoid : ternary_device<hip_triadd_sigmoid, &device::add_sigmoid>
{
};
struct hip_triadd_tanh : ternary_device<hip_triadd_tanh, &device::add_tanh>
{
};
struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu>
{
};
struct hip_add_sigmoid : binary_device<hip_add_sigmoid, &device::add_sigmoid>
{
};
struct hip_add_tanh : binary_device<hip_add_tanh, &device::add_tanh>
{
std::string name() const { return "hip::add_relu"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(3);
return inputs.front();
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::add_relu(ctx.get_stream().get(), args.at(2), args.at(0), args.at(1));
return args.at(2);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
struct hip_mul_add
@@ -258,11 +245,14 @@ void move_standard_front(std::vector<instruction_ref>& args)
std::swap(*it, args.front());
}
struct find_add_relu
struct find_add_unary
{
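// op_name is the unary op to match (e.g. "gpu::sigmoid"); binary_add_op and
// ternary_add_op are the fused ops substituted for gpu::add / hip::triadd.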
std::string op_name;
operation binary_add_op;
operation ternary_add_op;
auto matcher() const
{
return match::name("gpu::relu")(match::arg(0)(
return match::name(op_name)(match::arg(0)(
match::used_once(),
match::any_of(match::name("gpu::add"),
match::name("hip::triadd"),
@@ -282,9 +272,9 @@ struct find_add_relu
// Use the allocation from the unary operator
args.back() = ins->inputs().back();
if(add_ins->name() == "gpu::add")
p.replace_instruction(ins, hip_add_relu{}, args);
p.replace_instruction(ins, binary_add_op, args);
else if(add_ins->name() == "hip::triadd")
p.replace_instruction(ins, hip_triadd_relu{}, args);
p.replace_instruction(ins, ternary_add_op, args);
}
};
@@ -521,7 +511,9 @@ void fuse_ops::apply(program& p) const
find_conv_bias{ctx},
find_mul_add{},
find_mul_add_relu{},
find_add_relu{}
find_add_unary{"gpu::relu", hip_add_relu{}, hip_triadd_relu{}},
find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}}
);
// clang-format on
}
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ADD_RELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ADD_RELU_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ADD_UNARY_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ADD_UNARY_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
@@ -22,12 +22,34 @@ void add_relu(hipStream_t stream,
const argument& arg1,
const argument& arg2);
void add_sigmoid(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2);
void add_tanh(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2);
void add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
void add_sigmoid(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
void add_tanh(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
@@ -74,6 +74,39 @@ struct binary_device : oper<Derived>
}
};
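// CRTP helper for fused ternary ops: F is the device kernel to invoke; the last
// input is the preallocated output buffer, which compute() fills and output_alias()
// reports so the allocation can be reused.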
template <class Derived,
void (*F)(
hipStream_t, const argument&, const argument&, const argument&, const argument&)>
struct ternary_device : oper<Derived>
{
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4);
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
auto s2 = inputs.at(2);
if(s0 == s1 and s1 == s2 and s0.packed())
{
return s0;
}
else
{
return {s0.type(), s0.lens()};
}
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
F(ctx.get_stream().get(), args[3], args[0], args[1], args[2]);
return args[3];
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
@@ -827,6 +827,77 @@ struct test_add_relu : verify_program<test_add_relu>
}
};
struct test_add_sigmoid : verify_program<test_add_sigmoid>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto add = p.add_instruction(migraphx::op::add{}, x, y);
p.add_instruction(migraphx::op::sigmoid{}, add);
return p;
}
};
struct test_add_tanh : verify_program<test_add_tanh>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto add = p.add_instruction(migraphx::op::add{}, x, y);
p.add_instruction(migraphx::op::tanh{}, add);
return p;
}
};
struct test_triadd_relu : verify_program<test_triadd_relu>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto z = p.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto sum = p.add_instruction(migraphx::op::add{}, x, y);
auto triadd = p.add_instruction(migraphx::op::add{}, sum, z);
p.add_instruction(migraphx::op::relu{}, triadd);
return p;
}
};
struct test_triadd_sigmoid : verify_program<test_triadd_sigmoid>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto z = p.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto sum = p.add_instruction(migraphx::op::add{}, x, y);
auto triadd = p.add_instruction(migraphx::op::add{}, sum, z);
p.add_instruction(migraphx::op::sigmoid{}, triadd);
return p;
}
};
struct test_triadd_tanh : verify_program<test_triadd_tanh>
{
migraphx::program create_program() const
{
migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto z = p.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto sum = p.add_instruction(migraphx::op::add{}, x, y);
auto triadd = p.add_instruction(migraphx::op::add{}, sum, z);
p.add_instruction(migraphx::op::tanh{}, triadd);
return p;
}
};
struct test_sigmoid : verify_program<test_sigmoid>
{
migraphx::program create_program() const
......