Commit 8e68f6ca authored by Paul

Add fused mul add

parent 16864eef
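For reference, the fused operation computes out = a * x + b element-wise, with a and b broadcast along a single axis of x (for example a per-channel scale and bias). A minimal scalar sketch of the intended semantics, using illustrative names that are not part of this commit:

    #include <cstddef>
    #include <vector>

    // Reference semantics for mul_add with 1-D scale/bias broadcast over `axis`.
    // x and out use shape `lens` with packed row-major `strides`; a and b have length lens[axis].
    void mul_add_reference(const std::vector<float>& x,
                           const std::vector<float>& a,
                           const std::vector<float>& b,
                           std::vector<float>& out,
                           const std::vector<std::size_t>& lens,
                           const std::vector<std::size_t>& strides,
                           std::size_t axis)
    {
        for(std::size_t i = 0; i < x.size(); i++)
        {
            // Recover the coordinate along the broadcast axis from the flat index.
            std::size_t k = (i / strides[axis]) % lens[axis];
            out[i]        = a[k] * x[i] + b[k];
        }
    }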
@@ -16,6 +16,7 @@ add_library(migraphx_device
     device/argmin.cpp
     device/max.cpp
     device/min.cpp
+    device/mul_add.cpp
     device/exp.cpp
     device/erf.cpp
     device/log.cpp
...
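The new device/mul_add.cpp itself is not shown in this diff; only its registration here and its callers appear. Judging from how hip_mul_add invokes it below, device::mul_add(stream, result, x, a, b), and from how the nary helpers order the two broadcast arguments, it presumably just wraps nary with a fused multiply-add lambda. A sketch under those assumptions (the exact file contents and namespaces are guesses, not part of this diff):

    // Hypothetical sketch of device/mul_add.cpp
    #include <migraphx/gpu/device/mul_add.hpp>
    #include <migraphx/gpu/device/nary.hpp>

    namespace migraphx {
    namespace gpu {
    namespace device {

    void mul_add(hipStream_t stream,
                 const argument& result,
                 const argument& arg1,
                 const argument& arg2,
                 const argument& arg3)
    {
        // nary() dispatches to the standard, single-broadcast, or double-broadcast
        // kernels; the last two arguments are the (possibly broadcast) scale and bias.
        nary(stream, result, arg1, arg2, arg3)(
            [](auto x, auto a, auto b) __device__ { return a * x + b; });
    }

    } // namespace device
    } // namespace gpu
    } // namespace migraphx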
@@ -118,10 +118,66 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
     });
 }
 
+template <class F, class... Arguments>
+void nary_double_broadcast_vec_impl(
+    hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
+{
+    assert(barg1.get_shape() == barg2.get_shape());
+    const auto& output_shape = result.get_shape();
+    const auto& b_shape      = barg1.get_shape();
+    auto bdim =
+        std::distance(b_shape.strides().begin(),
+                      std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
+                          return x != 0;
+                      }));
+    auto bdim_len         = output_shape.lens()[bdim];
+    auto bdim_stride      = output_shape.strides()[bdim];
+    auto bdim_next_stride = bdim_stride * bdim_len;
+    const std::size_t vec_size     = 4;
+    const std::size_t nlocal       = 1024;
+    const std::size_t nglobal      = 256 * nlocal;
+    const std::size_t bdim_vec_len = bdim_len / vec_size;
+    hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
+        [&](auto output, auto binput1, auto binput2, auto... inputs) {
+            using type                  = typename decltype(output)::value_type;
+            const std::size_t nelements = output.size() / vec_size;
+            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
+                MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
+                // Load bias into LDS
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i] = binput1.data()[i];
+                }
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i + bdim_vec_len] = binput2.data()[i];
+                }
+                __syncthreads();
+                auto* bp = as_pointer(buffer);
+                // Process the data
+                for(size_t i = idx.global; i < nelements; i += nglobal)
+                {
+                    auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
+                    auto b1   = bp[bidx];
+                    auto b2   = bp[bidx + bdim_vec_len];
+                    auto out  = output.data()[i];
+                    for(std::size_t j = 0; j < vec_size; j++)
+                    {
+                        out[j] = f(inputs.data()[i][j]..., b2, b1);
+                    }
+                    output.data()[i] = out;
+                }
+            });
+        });
+}
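In the kernel above, the two broadcast tensors are staged in LDS back to back (binput2 at offset bdim_vec_len), and the expression ((i * vec_size) % bdim_next_stride) / bdim_stride recovers the coordinate along the broadcast axis for the first scalar of each 4-wide vector; the % 4 == 0 conditions checked by broadcastable() below ensure all four lanes share that coordinate. A small host-side check of the index arithmetic with illustrative NCHW-style numbers:

    #include <cassert>
    #include <cstddef>

    int main()
    {
        const std::size_t bdim_len         = 64;    // e.g. channels
        const std::size_t bdim_stride      = 7 * 7; // stride of the broadcast axis
        const std::size_t bdim_next_stride = bdim_stride * bdim_len;

        // A flat index inside channel 3 of the first image maps to bias slot 3 ...
        assert(((3 * bdim_stride + 5) % bdim_next_stride) / bdim_stride == 3);
        // ... and the same slot repeats for the next image in the batch.
        assert(((bdim_next_stride + 3 * bdim_stride + 5) % bdim_next_stride) / bdim_stride == 3);
        return 0;
    }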
 template <class F, class... Arguments>
 void nary_double_broadcast_impl(
     hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
 {
+    assert(barg1.get_shape() == barg2.get_shape());
     const auto& output_shape = result.get_shape();
     const auto& b_shape      = barg1.get_shape();
     auto bdim =
@@ -148,7 +204,7 @@ void nary_double_broadcast_impl(
                 }
                 for(size_t i = idx.local; i < bdim_len; i += nlocal)
                 {
-                    buffer[i + bdim_len] = binput2.data()[i + bdim_len];
+                    buffer[i + bdim_len] = binput2.data()[i];
                 }
                 __syncthreads();
                 // Process the data
@@ -157,7 +213,7 @@ void nary_double_broadcast_impl(
                     auto bidx = (i % bdim_next_stride) / bdim_stride;
                     auto b1   = buffer[bidx];
                     auto b2   = buffer[bidx + bdim_len];
-                    output.data()[i] = f(inputs.data()[i]..., b1, b2);
+                    output.data()[i] = f(inputs.data()[i]..., b2, b1);
                 }
             });
         });
@@ -222,7 +278,7 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
 }
 
 template <class... Arguments>
-bool broadcastable(bool& divisible_by_4, argument result, argument barg, Arguments... args)
+bool broadcastable(bool& divisible_by_4, std::size_t max_size, argument result, argument barg, Arguments... args)
 {
     divisible_by_4 = false;
     auto bshape    = barg.get_shape();
@@ -240,7 +296,7 @@ bool broadcastable(bool& divisible_by_4, argument result, argument barg, Arguments... args)
     auto b_len    = result.get_shape().lens()[b_idx];
     auto b_stride = result.get_shape().strides()[b_idx];
     assert(bshape.lens()[b_idx] == b_len);
-    if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
+    if(b_len <= max_size and std::none_of(std::next(b_it), strides.end(), not_zero))
     {
         divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
@@ -251,7 +307,7 @@ bool broadcastable(bool& divisible_by_4, argument result, argument barg, Arguments... args)
     return false;
 }
 
-inline bool broadcastable(bool& divisible_by_4, argument, argument)
+inline bool broadcastable(bool& divisible_by_4, std::size_t, argument, argument)
 {
     divisible_by_4 = false;
     return false;
@@ -274,7 +330,7 @@ inline auto nary(hipStream_t stream, argument result, argument arg, argument barg)
 {
     return [=](auto f) {
         bool divisible_by_4 = false;
-        if(broadcastable(divisible_by_4, result, barg, arg))
+        if(broadcastable(divisible_by_4, 2048, result, barg, arg))
         {
             if(divisible_by_4)
                 nary_broadcast_vec_impl(stream, f, result, barg, arg);
@@ -293,9 +349,24 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
 {
     return [=](auto f) {
         auto barg1 = back_args(args...);
-        bool fallback = pop_back_args(args...)([&](auto&&... args2) {
+        bool fallback1 = pop_back_args(args...)([&](auto&&... args2) {
+            // Try the double-broadcast kernels only when the last two arguments
+            // share the same broadcasted shape; otherwise fall through.
+            auto barg2     = back_args(args2...);
+            bool fallback2 = barg2.get_shape() != barg1.get_shape() or
+                             not barg2.get_shape().broadcasted() or
+                             pop_back_args(args2...)([&](auto&&... args3) {
+                                 bool divisible_by_4 = false;
+                                 if(broadcastable(divisible_by_4, 1024, result, barg2, args3...))
+                                 {
+                                     if(divisible_by_4)
+                                         nary_double_broadcast_vec_impl(
+                                             stream, f, result, barg1, barg2, args3...);
+                                     else
+                                         nary_double_broadcast_impl(
+                                             stream, f, result, barg1, barg2, args3...);
+                                     return false;
+                                 }
+                                 return true;
+                             });
+            if(not fallback2)
+                return false;
             bool divisible_by_4 = false;
-            if(broadcastable(divisible_by_4, result, barg1, args2...))
+            if(broadcastable(divisible_by_4, 2048, result, barg1, args2...))
             {
                 if(divisible_by_4)
                     nary_broadcast_vec_impl(stream, f, result, barg1, args2...);
@@ -305,7 +376,7 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
             }
             return true;
         });
-        if(fallback)
+        if(fallback1)
             nary_impl(stream, f, result, args...);
     };
 }
...
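Two details of the dispatch above are easy to miss. First, the double-broadcast path is only taken when the last two arguments share the same broadcasted shape, and its size limit is 1024 rather than 2048 because the single 2048-element LDS buffer now has to hold both broadcast tensors. Second, barg1 is the last data argument but is staged first in the buffer, so the kernels pass (b2, b1) to f; the lambda therefore still sees its arguments in the caller's order. A hypothetical caller (not part of this diff) illustrating that order:

    // Fused a*x + b where a and b are broadcast along the same axis. They are the
    // last two arguments, so nary() can take the double-broadcast path, and the
    // lambda receives (x, a, b) in the same order they were passed.
    nary(stream, result, x, a, b)(
        [](auto xi, auto ai, auto bi) __device__ { return ai * xi + bi; });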
@@ -2,6 +2,7 @@
 #include <migraphx/matcher.hpp>
 #include <migraphx/gpu/miopen.hpp>
 #include <migraphx/gpu/convolution.hpp>
+#include <migraphx/gpu/device/mul_add.hpp>
 #include <migraphx/gpu/device/add_relu.hpp>
 #include <migraphx/gpu/device/add.hpp>
 #include <migraphx/instruction.hpp>
@@ -200,21 +201,42 @@ struct hip_add_relu
     }
 };
 
+struct hip_mul_add
+{
+    std::string name() const { return "hip::mul_add"; }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(4);
+        return inputs.front();
+    }
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
+    {
+        device::mul_add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
+        return args.at(3);
+    }
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
+};
 
 void move_broadcasted_back(std::vector<instruction_ref>& args)
 {
     // Ensure the last arguments is the broadcasted one
+    auto last = std::prev(args.end());
     auto it = std::find_if(
-        args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
-    if(it != args.end())
-        std::swap(*it, *std::prev(args.end(), 2));
+        args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
+    if(it != last)
+        std::swap(*it, *std::prev(last));
 }
 
 void move_standard_front(std::vector<instruction_ref>& args)
 {
     // Ensure the first arguments is the standard one
+    auto last = std::prev(args.end());
     auto it = std::find_if(
-        args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
-    if(it != args.end())
+        args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
+    if(it != last)
         std::swap(*it, args.front());
 }
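The change to these two helpers matters because each gpu::* instruction in this pass carries its output allocation as the last input (the argument that output_alias points at); the reordering of broadcast/standard arguments must leave that trailing argument where it is, so the search now stops at prev(end). A self-contained illustration of the same pattern with stand-in values (not the MIGraphX types):

    #include <algorithm>
    #include <cassert>
    #include <string>
    #include <vector>

    int main()
    {
        // Stand-ins for instruction inputs; the trailing "alloc" is the output buffer.
        std::vector<std::string> args{"broadcast_a", "x", "alloc"};
        auto last = std::prev(args.end());

        // Move the standard argument to the front, never touching the allocation.
        auto it = std::find_if(args.begin(), last, [](const auto& a) { return a == "x"; });
        if(it != last)
            std::swap(*it, args.front());

        assert((args == std::vector<std::string>{"x", "broadcast_a", "alloc"}));
        return 0;
    }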
@@ -278,6 +300,32 @@ struct find_triadd
     }
 };
 
+struct find_mul_add
+{
+    auto matcher() const
+    {
+        return match::name("gpu::add")(
+            match::either_arg(0, 1)(match::name("gpu::mul").bind("mul"), match::any().bind("b")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto mul_ins = r.instructions["mul"];
+        auto b_ins   = r.instructions["b"];
+        auto ins     = r.result;
+        auto args    = mul_ins->inputs();
+        assert(mul_ins != b_ins);
+        move_standard_front(args);
+        move_broadcasted_back(args);
+        args.insert(std::prev(args.end()), b_ins);
+        args.back() = ins->inputs().back();
+        p.replace_instruction(ins, hip_mul_add{}, args);
+    }
+};
+
 struct miopen_conv_bias
 {
     op::convolution op;
@@ -433,7 +481,8 @@ void fuse_ops::apply(program& p) const
     match::find_matches(p,
                         find_conv_bias_relu{ctx},
                         find_conv_bias{ctx},
-                        find_add_relu{}
+                        find_add_relu{},
+                        find_mul_add{}
     );
     // clang-format on
 }
...
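find_mul_add rewrites add(mul(x, a), b) into the fused hip::mul_add; match::either_arg(0, 1) accepts the mul as either operand of the add, so the commuted form add(b, mul(x, a)) is fused as well. apply() then rebuilds the argument list as {x, broadcast a, b, output allocation}, which is exactly the order device::mul_add is called with in hip_mul_add::compute. A sketch of the commuted variant of the test program below (illustrative; these op::mul/op::add instructions only become gpu::mul/gpu::add, and thus match, after GPU lowering):

    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    migraphx::shape bs{migraphx::shape::float_type, {3}};
    auto x   = p.add_parameter("x", s);
    auto a   = p.add_parameter("a", bs);
    auto b   = p.add_parameter("b", bs);
    auto ab  = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, a);
    auto bb  = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, b);
    auto mul = p.add_instruction(migraphx::op::mul{}, x, ab);
    p.add_instruction(migraphx::op::add{}, bb, mul); // mul as the second operand also matches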
@@ -490,6 +490,24 @@ struct test_triadd2 : verify_program<test_triadd2>
     }
 };
 
+struct test_mul_add : verify_program<test_mul_add>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
+        migraphx::shape bs{migraphx::shape::float_type, {3}};
+        auto x   = p.add_parameter("x", s);
+        auto a   = p.add_parameter("a", bs);
+        auto b   = p.add_parameter("b", bs);
+        auto ab  = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, a);
+        auto bb  = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, b);
+        auto mul = p.add_instruction(migraphx::op::mul{}, x, ab);
+        p.add_instruction(migraphx::op::add{}, mul, bb);
+        return p;
+    }
+};
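With these shapes, broadcast{1, s.lens()} repeats the two length-3 vectors across the rows of x, so the verified result is out[i][j] = a[j] * x[i][j] + b[j]. A worked example with illustrative values:

    x = {{1, 2, 3}, {4, 5, 6}},  a = {2, 3, 4},  b = {10, 20, 30}
    a * x + b = {{2*1 + 10, 3*2 + 20, 4*3 + 30},   = {{12, 26, 42},
                 {2*4 + 10, 3*5 + 20, 4*6 + 30}}      {18, 35, 54}}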
 struct test_add_broadcast : verify_program<test_add_broadcast>
 {
     migraphx::program create_program() const
...