Commit a5181cd0 authored by Shucai Xiao

layernorm kernel optimization

parent d5c2538c
@@ -2,6 +2,8 @@
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/pow.hpp>
#include <migraphx/gpu/device/fast_div.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -94,9 +96,9 @@ __device__ void layernorm(index_int i,
const bool in_range = idx.local < relements_v;
auto mean = [&](auto z) {
auto m = auto_block_reduce<MaxBlockSize>(idx, sum{}, value_type(0), relements_v, [=](auto) {
return z / value_type(relements);
});
#if MIGRAPHX_WORKAROUND_NAVI_DPP_SYNC
__builtin_amdgcn_s_barrier();
#endif
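The rewritten mean folds the division by relements into the per-element reduction lambda, pre-scaling each term before accumulation instead of dividing the reduced total once, presumably to keep partial sums small for low-precision value types. A minimal host-side sketch (not MIGraphX code) of the two algebraically equal forms:

#include <numeric>
#include <vector>

// old form: reduce first, divide the total once
float mean_divide_after(const std::vector<float>& x)
{
    return std::accumulate(x.begin(), x.end(), 0.0f) / x.size();
}

// new form: pre-scale each term inside the reduction, as the lambda now does
float mean_divide_inside(const std::vector<float>& x)
{
    float s = 0.0f;
    for(float v : x)
        s += v / x.size();
    return s;
}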
@@ -158,7 +160,7 @@ void layernorm_impl(hipStream_t stream,
const Arguments&... args)
{
hip_visit_all(result, args...)([&](auto output, auto... inputs) {
const std::size_t max_block_size = 128;
const std::size_t block_size = compute_block_size(relements, max_block_size);
const std::size_t block_size_div = encode_divisor(block_size);
assert(relements <= block_size);
@@ -200,14 +202,230 @@ auto layernorm_fusion(hipStream_t stream,
};
}
struct half2_sum
{
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 x, __half2 y) const { return __hadd2(x, y); }
};
// buffer points into shared memory; the reduction assumes
// batch_item_num <= 2 * block_size, since the first fold (s == block_size)
// already combines elements [block_size, 2 * block_size)
template <class Op>
__device__ __half2 block_reduce_half2(
__half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
__syncthreads();
for(index_int s = block_size; s > 0; s >>= 1)
{
if(tid < s and tid + s < batch_item_num)
{
buffer[tid] = op(buffer[tid], buffer[tid + s]);
}
__syncthreads();
}
// sum the low and high 16-bit lanes of buffer[0] so that both lanes of the
// returned __half2 carry the full reduction result
auto lows2 = __low2half2(buffer[0]);
auto highs2 = __high2half2(buffer[0]);
return op(lows2, highs2);
}
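The first fold of this tree reduction (s == block_size) already combines elements [block_size, 2 * block_size), so the routine handles up to 2 * block_size entries with block_size threads. A CPU analog of the same indexing, illustrative only and assuming block_size is a power of two:

#include <cstddef>
#include <vector>

float tree_reduce(std::vector<float> buf, std::size_t block_size)
{
    const std::size_t n = buf.size();
    for(std::size_t s = block_size; s > 0; s >>= 1)
        for(std::size_t tid = 0; tid < s; ++tid) // tid plays the role of threadIdx.x
            if(tid + s < n)
                buf[tid] += buf[tid + s];
    return buf[0];
}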
// m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12)
__global__ void triadd_layernorm_kernel_half2(
void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size)
{
__half2* input1 = reinterpret_cast<__half2*>(in1);
__half2* input2 = reinterpret_cast<__half2*>(in2);
__half2* input3 = reinterpret_cast<__half2*>(in3);
__half2* output = reinterpret_cast<__half2*>(data_out);
batch_item_num /= 2; // two half values per __half2 element
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
__half2* in_data_reduce = buffer2;
__half2* in_data = buffer2 + batch_item_num;
int start = blockIdx.x * batch_item_num;
auto rnum = __float2half2_rn(1.0f / batch_item_num);
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
in_data[i] = __hadd2(__hadd2(input1[idx], input2[idx]), input3[idx]);
in_data_reduce[i] = __hmul2(in_data[i], rnum);
}
auto m =
block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = __hsub2(in_data[i], m);
in_data_reduce[i] = __hmul2(__hmul2(in_data[i], in_data[i]), rnum);
}
m = block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
auto eps = __float2half2_rn(1.0e-12f);
auto r = __hadd2(m, eps);
r = h2rsqrt(r);
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
output[idx] = __hmul2(in_data[i], r);
}
}
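Since each __half2 packs two values, the kernel halves batch_item_num and carves two arrays out of dynamic shared memory, in_data_reduce followed by in_data. A sketch of the byte budget, which matches the shared_size the host code further below computes as batch_item_num * 2 * in_s.type_size():

#include <hip/hip_fp16.h>
#include <cstddef>

// Two shared arrays of batch_item_num / 2 __half2 each; this equals
// batch_item_num * 2 * sizeof(__half) bytes, as the host code computes.
std::size_t half2_shared_bytes(std::size_t batch_item_num)
{
    return 2 * (batch_item_num / 2) * sizeof(__half2);
}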
template <class T>
__device__ T
block_reduce_half(T* buffer, index_int batch_item_num, index_int tid, index_int block_size)
{
__syncthreads();
for(index_int s = block_size; s > 0; s >>= 1)
{
if(tid < s and tid + s < batch_item_num)
{
buffer[tid] = __float2half(__half2float(buffer[tid]) + __half2float(buffer[tid + s]));
}
__syncthreads();
}
return buffer[0];
}
// m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12)
__global__ void triadd_layernorm_kernel_half(
void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size)
{
__half* input1 = reinterpret_cast<__half*>(in1);
__half* input2 = reinterpret_cast<__half*>(in2);
__half* input3 = reinterpret_cast<__half*>(in3);
__half* output = reinterpret_cast<__half*>(data_out);
extern MIGRAPHX_DEVICE_SHARED __half bufferh[];
__half* in_data_reduce = bufferh;
__half* in_data = bufferh + batch_item_num;
int start = blockIdx.x * batch_item_num;
auto rnum = 1.0f / batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
in_data[i] = __float2half(__half2float(input1[idx]) + __half2float(input2[idx]) +
__half2float(input3[idx]));
in_data_reduce[i] = __float2half(__half2float(in_data[i]) * rnum); // rnum is already float
}
auto m = block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size);
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = __float2half(__half2float(in_data[i]) - __half2float(m));
in_data_reduce[i] =
__float2half(__half2float(in_data[i]) * __half2float(in_data[i]) * rnum);
}
m = __float2half(
__half2float(block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size)) +
1.0e-12f);
auto r = __float2half(rsqrt(__half2float(m)));
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
output[idx] = __float2half(__half2float(in_data[i]) * __half2float(r));
}
}
template <class T>
__device__ T block_reduce(T* buffer, index_int batch_item_num, index_int tid, index_int block_size)
{
__syncthreads();
for(index_int s = block_size; s > 0; s >>= 1)
{
if(tid < s and tid + s < batch_item_num)
{
buffer[tid] = buffer[tid] + buffer[tid + s];
}
__syncthreads();
}
return buffer[0];
}
// m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12)
template <class T>
__global__ void triadd_layernorm_kernel(
void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size)
{
T* input1 = reinterpret_cast<T*>(in1);
T* input2 = reinterpret_cast<T*>(in2);
T* input3 = reinterpret_cast<T*>(in3);
T* output = reinterpret_cast<T*>(data_out);
extern MIGRAPHX_DEVICE_SHARED T buffer[];
T* in_data_reduce = buffer;
T* in_data = buffer + batch_item_num;
int start = blockIdx.x * batch_item_num;
auto rnum = 1.0f / batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
in_data[i] = input1[idx] + input2[idx] + input3[idx];
in_data_reduce[i] = in_data[i] * rnum;
}
auto m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size);
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = in_data[i] - m;
in_data_reduce[i] = in_data[i] * in_data[i] * rnum;
}
m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size) + 1.0e-12f;
auto r = rsqrt(m);
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
output[idx] = in_data[i] * r;
}
}
void triadd_layernorm(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
auto in_s = arg1.get_shape();
auto type = in_s.type();
auto batch_item_num = in_s.lens().back();
if(type == shape::half_type and (batch_item_num % 2) == 0)
{
auto half2_block_size = compute_block_size(batch_item_num, 1024);
int block_num = in_s.elements() / batch_item_num;
int shared_size = batch_item_num * 2 * in_s.type_size();
half2_block_size = half2_block_size / 4; // quarter the block so each thread handles several __half2 in the strided loops
triadd_layernorm_kernel_half2<<<block_num, half2_block_size, shared_size, stream>>>(
arg1.data(), arg2.data(), arg3.data(), result.data(), batch_item_num, half2_block_size);
}
// if(type == shape::half_type and (batch_item_num % 2) == 0)
// {
// auto reduce_block_size = compute_block_size(batch_item_num, 1024);
// int block_num = in_s.elements() / batch_item_num;
// int shared_size = batch_item_num * 2 * in_s.type_size();
// reduce_block_size = reduce_block_size / 2;
// triadd_layernorm_kernel_half<<<block_num, reduce_block_size, shared_size, stream>>>(
// arg1.data(),
// arg2.data(),
// arg3.data(),
// result.data(),
// batch_item_num,
// reduce_block_size);
// }
else
{
layernorm_fusion(stream, result, arg1, arg2, arg3)(
[](auto x, auto y, auto z) { return x + y + z; },
[](auto x, auto& y, auto...) { y = x; });
}
}
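Worked launch numbers for a hypothetical half input of shape {8, 384, 768}, i.e. batch_item_num = 768, following the host logic above; the result of compute_block_size is an assumption (rounding 768 up to the next power of two):

// block_num        = (8 * 384 * 768) / 768            = 3072 blocks
// shared_size      = 768 * 2 * sizeof(__half)         = 3072 bytes
// half2_block_size = compute_block_size(768, 1024) / 4 = 1024 / 4 = 256 threads
// inside the kernel batch_item_num becomes 768 / 2 = 384 __half2 elements,
// so each strided loop takes at most ceil(384 / 256) = 2 iterations per thread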
void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
......