Unverified Commit 1bfb147d authored by Paul Fultz II, committed by GitHub

Fuse skip layernorm (#683)



* Unify the vectorized and non-vectorized path

* Formatting

* Make fusion easily extendable

* Add skip layernorm fusion

* Formatting

* Call correct layernorm function

* Fix compile errors

* Add DCE

* Add test for skip layernorm

* Formatting

* Remove unused typedef

* Formatting

* Fix tidy issues

* Formatting
Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
parent a66761ea
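For orientation: this change lets the GPU backend run a skip connection (a three-way add) and the layernorm that follows it as one fused kernel instead of two. A minimal scalar sketch of what the fused op computes per reduction row is given below; triadd_layernorm_ref is an illustrative name, not code from this commit.

#include <cmath>
#include <cstddef>
#include <vector>

// Reference semantics (scalar, single reduction row of `relements` values):
//   t   = x + y + z                                        -- the fused skip add
//   out = (t - mean(t)) * rsqrt(mean((t - mean(t))^2) + 1e-12)
std::vector<float> triadd_layernorm_ref(const std::vector<float>& x,
                                        const std::vector<float>& y,
                                        const std::vector<float>& z)
{
    std::size_t n = x.size();
    std::vector<float> t(n), out(n);
    float mean = 0;
    for(std::size_t i = 0; i < n; i++)
    {
        t[i] = x[i] + y[i] + z[i];
        mean += t[i];
    }
    mean /= n;
    float var = 0;
    for(std::size_t i = 0; i < n; i++)
        var += (t[i] - mean) * (t[i] - mean);
    var /= n;
    float inv = 1.0f / std::sqrt(var + 1e-12f);
    for(std::size_t i = 0; i < n; i++)
        out[i] = (t[i] - mean) * inv;
    return out;
}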
File mode changed from 100644 to 100755
@@ -22,19 +22,101 @@ struct vector_type<vec<T, N>>
template <class T>
using vector_type_t = typename vector_type<T>::type;
template <class T>
struct vector_size : std::integral_constant<index_int, 1>
{
};
template <class T, index_int N>
struct vector_size<vec<T, N>> : std::integral_constant<index_int, N>
{
};
template <class T, class F>
__device__ auto vec_transform(T x, F f)
{
return f(x);
}
template <class T, index_int N, class F>
__device__ auto vec_transform(vec<T, N> x, F f)
{
vec<T, N> y = x;
// cppcheck-suppress useStlAlgorithm
for(index_int k = 0; k < N; k++)
y[k] = f(x[k]);
return y;
}
template <class T, class U, class Op>
__device__ auto vec_reduce(T x, U, Op)
{
return x;
}
template <class T, index_int N, class U, class Op>
__device__ auto vec_reduce(vec<T, N> x, U init, Op op)
{
T r = init;
for(index_int k = 0; k < N; k++)
r = op(r, x[k]);
return r;
}
template <index_int N, class Op, class T, class F>
__device__ auto auto_block_reduce(index idx, Op op, T init, index_int n, F f)
{
auto r = block_reduce<N>(idx, op, init, n, f);
return vec_reduce(r, 0, op);
}
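// Taken together, the helpers above let one kernel body serve both the scalar
// and the vectorized path: vec_transform applies a function lane-wise (and is a
// plain call for scalars), vec_reduce folds the lanes of a vec<T, N> with the
// given op (and passes scalars through unchanged), and auto_block_reduce runs
// the block reduction over the possibly-vectorized partials and then collapses
// the lanes, so the caller always receives a single reduced value.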
template <index_int MaxBlockSize, class Input, class Output>
__device__ void layernorm(index_int i,
index idx,
std::size_t block_size_div,
index_int relements,
Input input,
Output output)
{
using value_type = decltype(input(idx.local));
const auto relements_v = relements / vector_size<value_type>{};
const auto out_idx = fast_div(i, block_size_div);
const auto base_idx = out_idx * relements_v;
const auto input_idx = base_idx + idx.local;
const bool in_range = idx.local < relements_v;
auto mean = [&](auto z) {
return auto_block_reduce<MaxBlockSize>(
idx, sum{}, value_type(0), relements_v, [=](auto) { return z; }) /
value_type(relements);
};
// m = x - mean(x)
value_type x = in_range ? input(input_idx) : 0;
value_type m = x - mean(x);
// mean(m ^ 2) + 1e-12
value_type r = mean(m * m) + value_type(1e-12);
// m * rsqrt(mean(m ^ 2) + 1e-12)
if(in_range)
output(input_idx, m * vec_transform(r, &rsqrt));
}
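// Note that the pre-normalization math and the final store are injected through
// the Input and Output callables: input(idx) may combine several tensors
// (e.g. the three-way skip add) before the mean/variance reduction, and
// output(idx, v) decides how the normalized value is written back.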
// m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12)
template <index_int N>
template <index_int N, class Input, class Output, class... Arguments>
void layernorm_vec_impl(hipStream_t stream,
const argument& result,
const argument& arg1,
index_int nelements,
index_int relements)
index_int relements,
Input in,
Output out,
const argument& result,
const Arguments&... args)
{
hip_vec_visit_all<N>(result, arg1)([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type;
hip_vec_visit_all<N>(result, args...)([&](auto output, auto... inputs) {
const auto relements_v = relements / N;
const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements_v, max_block_size);
@@ -42,96 +124,85 @@ void layernorm_vec_impl(hipStream_t stream,
assert(relements_v <= block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = fast_div(i, block_size_div);
const auto base_idx = out_idx * relements_v;
const auto input_idx = base_idx + idx.local;
const bool in_range = idx.local < relements_v;
auto mean = [&](auto z) {
auto psum = block_reduce<max_block_size>(
idx, sum{}, value_type(0), relements_v, [=](auto) { return z; });
vector_type_t<value_type> sum = 0;
for(index_int k = 0; k < N; k++)
sum += psum[k];
return sum / relements;
};
// m = x - mean(x)
value_type x = in_range ? input.data()[input_idx] : 0;
value_type m = x - mean(x);
// mean(m ^ 2) + 1e-12
value_type r = mean(m * m) + value_type(1e-12);
// rsqrt(mean(m ^ 2) + 1e-12)
value_type d = 0;
for(index_int k = 0; k < N; k++)
d[k] = ::rsqrt(r[k]);
// m * rsqrt(mean(m ^ 2) + 1e-12)
if(in_range)
output.data()[input_idx] = m * d;
layernorm<max_block_size>(
i,
idx,
block_size_div,
relements,
[&](auto input_idx) { return in(inputs.data()[input_idx]...); },
[&](auto input_idx, auto x) {
out(x, output.data()[input_idx], inputs.data()[input_idx]...);
});
});
});
}
template <class Input, class Output, class... Arguments>
void layernorm_impl(hipStream_t stream,
const argument& result,
const argument& arg1,
index_int nelements,
index_int relements)
index_int relements,
Input in,
Output out,
const argument& result,
const Arguments&... args)
{
hip_visit_all(result, arg1)([&](auto output, auto input) {
using value_type = typename decltype(input)::value_type;
hip_visit_all(result, args...)([&](auto output, auto... inputs) {
const std::size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(relements, max_block_size);
const std::size_t block_size_div = encode_divisor(block_size);
assert(relements <= block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = fast_div(i, block_size_div);
const auto base_idx = out_idx * relements;
const auto input_idx = base_idx + idx.local;
const bool in_range = idx.local < relements;
auto mean = [&](auto z) {
return block_reduce<max_block_size>(idx,
sum{},
value_type(0),
relements,
[=](auto) { return in_range ? z : 0; }) /
relements;
};
// m = x - mean(x)
value_type x = in_range ? input.data()[input_idx] : 0;
value_type m = x - mean(x);
// mean(m ^ 2) + 1e-12
value_type r = mean(m * m) + 1e-12;
// m * rsqrt(mean(m ^ 2) + 1e-12)
if(in_range)
output.data()[input_idx] = m * ::rsqrt(r);
layernorm<max_block_size>(
i,
idx,
block_size_div,
relements,
[&](auto input_idx) { return in(inputs.data()[input_idx]...); },
[&](auto input_idx, auto x) {
out(x, output.data()[input_idx], inputs.data()[input_idx]...);
});
});
});
}
template <class... Arguments>
auto layernorm_fusion(hipStream_t stream,
const argument& result,
const argument& arg1,
const Arguments&... args)
{
return [=](auto input, auto output) {
auto relements = arg1.get_shape().lens().back();
auto nelements = result.get_shape().elements() / relements;
auto output_shape = result.get_shape();
auto reduce_output_lens(output_shape.lens());
reduce_output_lens.back() = 1;
if((relements % 4) == 0)
layernorm_vec_impl<4>(
stream, nelements, relements, input, output, result, arg1, args...);
else if(relements < 256)
layernorm_impl(stream, nelements, relements, input, output, result, arg1, args...);
else
MIGRAPHX_THROW("No kernel for layernorm");
};
}
void triadd_layernorm(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
layernorm_fusion(stream, result, arg1, arg2, arg3)(
[](auto x, auto y, auto z) { return x + y + z; }, [](auto x, auto& y, auto...) { y = x; });
}
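// The first lambda fuses the element-wise inputs (here the three-way add of the
// skip connection); the second receives the normalized value together with the
// output element and the original inputs, and performs the store. Adding another
// fusion only requires supplying a different pair of lambdas to layernorm_fusion.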
void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
{
auto relements = arg1.get_shape().lens().back();
auto nelements = result.get_shape().elements() / relements;
auto output_shape = result.get_shape();
auto reduce_output_lens(output_shape.lens());
reduce_output_lens.back() = 1;
if((relements % 4) == 0)
layernorm_vec_impl<4>(stream, result, arg1, nelements, relements);
else if(relements < 256)
layernorm_impl(stream, result, arg1, nelements, relements);
else
MIGRAPHX_THROW("No kernel for layernorm");
layernorm_fusion(stream, result, arg1)([](auto x) { return x; },
[](auto x, auto& y, auto) { y = x; });
}
} // namespace device
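As an illustration of the extension point (not part of this diff), a two-input add + layernorm fusion could reuse the same wrapper; add_layernorm below is a hypothetical sketch only.

// Hypothetical two-input fusion built on layernorm_fusion (sketch, not in this commit).
void add_layernorm(hipStream_t stream,
                   const argument& result,
                   const argument& arg1,
                   const argument& arg2)
{
    layernorm_fusion(stream, result, arg1, arg2)(
        [](auto x, auto y) { return x + y; },      // fuse the inputs element-wise
        [](auto x, auto& y, auto...) { y = x; });  // store the normalized value
}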
@@ -238,6 +238,13 @@ struct hip_layernorm : unary_device<hip_layernorm, &device::layernorm>
};
MIGRAPHX_REGISTER_OP(hip_layernorm)
struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm>
{
// Empty finalize to skip dimension reduction
void finalize(context&, const shape&, const std::vector<shape>&) {}
};
MIGRAPHX_REGISTER_OP(hip_triadd_layernorm)
struct hip_gelu : unary_device<hip_gelu, &device::gelu>
{
};
@@ -341,6 +348,22 @@ struct find_layernorm
}
};
struct find_triadd_layernorm
{
auto matcher() const
{
return match::name("gpu::layernorm")(match::arg(0)(match::name("gpu::triadd")(
match::used_once(), match::all_of[match::inputs()](match::standard_shape()))));
}
void apply(program& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto triadd = ins->inputs().front();
p.replace_instruction(ins, hip_triadd_layernorm{}, triadd->inputs());
}
};
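// Rough effect of the rewrite (any output-buffer arguments the gpu ops carry are
// elided here for readability):
//   before:  t = gpu::triadd(x, y, z)        -- used once, standard shapes
//            r = gpu::layernorm(t)
//   after:   r = gpu::triadd_layernorm(x, y, z)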
struct find_gelu
{
@@ -827,7 +850,8 @@ void fuse_ops::apply(module& p) const
find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
find_add_clip{});
match::find_matches(p, find_gemm_add{}, find_commutative_broadcast{});
run_passes(p, {dead_code_elimination{}});
match::find_matches(p, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{});
}
} // namespace gpu
@@ -12,6 +12,12 @@ namespace device {
void layernorm(hipStream_t stream, const argument& result, const argument& arg1);
void triadd_layernorm(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
@@ -4,33 +4,32 @@
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
migraphx::instruction_ref add_layernorm(migraphx::program& p, std::vector<size_t> dims)
migraphx::instruction_ref
add_layernorm(migraphx::module& m, migraphx::instruction_ref x, std::vector<size_t> dims)
{
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
auto scale =
mm->add_parameter("scale", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
m.add_parameter("scale", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
auto bias =
mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
auto epsilon = mm->add_literal(1e-12f);
auto exponent = mm->add_literal(2.0f);
m.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {dims.back()}});
auto epsilon = m.add_literal(1e-12f);
auto exponent = m.add_literal(2.0f);
auto mean = mm->add_instruction(migraphx::op::reduce_mean({2}), x);
auto mean_mbcast = mm->add_instruction(migraphx::op::multibroadcast{{dims}}, mean);
auto sub = mm->add_instruction(migraphx::op::sub{}, x, mean_mbcast);
auto exponent_mbcast = mm->add_instruction(migraphx::op::multibroadcast{{dims}}, exponent);
auto pow = mm->add_instruction(migraphx::op::pow{}, sub, exponent_mbcast);
auto var = mm->add_instruction(migraphx::op::reduce_mean({2}), pow);
auto mean = m.add_instruction(migraphx::op::reduce_mean({2}), x);
auto mean_mbcast = m.add_instruction(migraphx::op::multibroadcast{{dims}}, mean);
auto sub = m.add_instruction(migraphx::op::sub{}, x, mean_mbcast);
auto exponent_mbcast = m.add_instruction(migraphx::op::multibroadcast{{dims}}, exponent);
auto pow = m.add_instruction(migraphx::op::pow{}, sub, exponent_mbcast);
auto var = m.add_instruction(migraphx::op::reduce_mean({2}), pow);
auto epsilon_mbcast =
mm->add_instruction(migraphx::op::multibroadcast{{1, dims.at(1), 1}}, epsilon);
auto add_epsilon = mm->add_instruction(migraphx::op::add{}, var, epsilon_mbcast);
auto sqrt = mm->add_instruction(migraphx::op::sqrt{}, add_epsilon);
auto sqrt_mbcast = mm->add_instruction(migraphx::op::multibroadcast{dims}, sqrt);
auto div = mm->add_instruction(migraphx::op::div{}, sub, sqrt_mbcast);
auto scale_mbcast = mm->add_instruction(migraphx::op::multibroadcast{dims}, scale);
auto mul = mm->add_instruction(migraphx::op::mul{}, scale_mbcast, div);
auto bias_mbcast = mm->add_instruction(migraphx::op::multibroadcast{dims}, bias);
return mm->add_instruction(migraphx::op::add{}, mul, bias_mbcast);
m.add_instruction(migraphx::op::multibroadcast{{1, dims.at(1), 1}}, epsilon);
auto add_epsilon = m.add_instruction(migraphx::op::add{}, var, epsilon_mbcast);
auto sqrt = m.add_instruction(migraphx::op::sqrt{}, add_epsilon);
auto sqrt_mbcast = m.add_instruction(migraphx::op::multibroadcast{dims}, sqrt);
auto div = m.add_instruction(migraphx::op::div{}, sub, sqrt_mbcast);
auto scale_mbcast = m.add_instruction(migraphx::op::multibroadcast{dims}, scale);
auto mul = m.add_instruction(migraphx::op::mul{}, scale_mbcast, div);
auto bias_mbcast = m.add_instruction(migraphx::op::multibroadcast{dims}, bias);
return m.add_instruction(migraphx::op::add{}, mul, bias_mbcast);
}
struct test_layernorm : verify_program<test_layernorm>
@@ -38,7 +37,10 @@ struct test_layernorm : verify_program<test_layernorm>
migraphx::program create_program() const
{
migraphx::program p;
add_layernorm(p, {1, 1, 5});
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 1, 5};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
add_layernorm(*mm, x, dims);
return p;
}
};
@@ -48,7 +50,27 @@ struct test_layernorm2 : verify_program<test_layernorm2>
migraphx::program create_program() const
{
migraphx::program p;
add_layernorm(p, {1, 4, 24});
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 4, 24};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
add_layernorm(*mm, x, dims);
return p;
}
};
struct test_layernorm_triadd : verify_program<test_layernorm_triadd>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 4, 24};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, dims});
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, dims});
auto add1 = mm->add_instruction(migraphx::op::add{}, x, y);
auto add2 = mm->add_instruction(migraphx::op::add{}, add1, z);
add_layernorm(*mm, add2, dims);
return p;
}
};