Commit fe849702 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

refactor of the layernorm code

parent 76547728
...@@ -215,28 +215,12 @@ __device__ __half2 block_reduce_half2( ...@@ -215,28 +215,12 @@ __device__ __half2 block_reduce_half2(
// m = x - mean(x) // m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12) // m / sqrt(mean(m ^ 2) + 1e-12)
__global__ void triadd_layernorm_kernel_half2( __device__ void layernorm_kernel_half2(__half2* in_data, __half2* in_data_reduce,
void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size) __half2* out, index_int batch_item_num, index_int block_size,
float rbatch_num)
{ {
__half2* input1 = reinterpret_cast<__half2*>(in1); auto rnum = __float2half2_rn(rbatch_num);
__half2* input2 = reinterpret_cast<__half2*>(in2);
__half2* input3 = reinterpret_cast<__half2*>(in3);
__half2* output = reinterpret_cast<__half2*>(data_out);
auto rnum = __float2half2_rn(1.0f / batch_item_num);
batch_item_num /= 2;
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[]; extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
__half2* in_data_reduce = buffer2;
__half2* in_data = buffer2 + batch_item_num;
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
in_data[i] = __hadd2(__hadd2(input1[idx], input2[idx]), input3[idx]);
in_data_reduce[i] = in_data[i];
// in_data_reduce[i] = __hmul2(in_data[i], rnum);
}
auto m = auto m =
block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{}); block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
m = __hmul2(m, rnum); m = __hmul2(m, rnum);
...@@ -244,7 +228,6 @@ __global__ void triadd_layernorm_kernel_half2( ...@@ -244,7 +228,6 @@ __global__ void triadd_layernorm_kernel_half2(
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
in_data[i] = __hsub2(in_data[i], m); in_data[i] = __hsub2(in_data[i], m);
// in_data_reduce[i] = __hmul2(__hmul2(in_data[i], in_data[i]), rnum);
in_data_reduce[i] = __hmul2(in_data[i], in_data[i]); in_data_reduce[i] = __hmul2(in_data[i], in_data[i]);
} }
...@@ -255,11 +238,36 @@ __global__ void triadd_layernorm_kernel_half2( ...@@ -255,11 +238,36 @@ __global__ void triadd_layernorm_kernel_half2(
auto r = __hadd2(m, eps); auto r = __hadd2(m, eps);
r = h2rsqrt(r); r = h2rsqrt(r);
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
out[idx] = __hmul2(in_data[i], r);
}
}
__global__ void triadd_layernorm_half2(
void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size)
{
__half2* input1 = reinterpret_cast<__half2*>(in1);
__half2* input2 = reinterpret_cast<__half2*>(in2);
__half2* input3 = reinterpret_cast<__half2*>(in3);
__half2* output = reinterpret_cast<__half2*>(data_out);
float rnum = 1.0f / batch_item_num;
batch_item_num /= 2;
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
__half2* in_data_reduce = buffer2;
__half2* in_data = buffer2 + batch_item_num;
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
int idx = i + start; int idx = i + start;
output[idx] = __hmul2(in_data[i], r); in_data[i] = __hadd2(__hadd2(input1[idx], input2[idx]), input3[idx]);
in_data_reduce[i] = in_data[i];
} }
layernorm_kernel_half2(in_data, in_data_reduce, output, batch_item_num, block_size, rnum);
} }
template <class T> template <class T>
...@@ -281,105 +289,55 @@ block_reduce_half(T* buffer, index_int batch_item_num, index_int tid, index_int ...@@ -281,105 +289,55 @@ block_reduce_half(T* buffer, index_int batch_item_num, index_int tid, index_int
// m = x - mean(x) // m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12) // m / sqrt(mean(m ^ 2) + 1e-12)
__global__ void triadd_layernorm_kernel_half( __device__ void layernorm_kernel_half(__half* in_data, __half* in_data_reduce, __half* out,
void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size) index_int batch_item_num, index_int block_size, float rnum)
{ {
__half* input1 = reinterpret_cast<__half*>(in1);
__half* input2 = reinterpret_cast<__half*>(in2);
__half* input3 = reinterpret_cast<__half*>(in3);
__half* output = reinterpret_cast<__half*>(data_out);
extern MIGRAPHX_DEVICE_SHARED __half bufferh[];
__half* in_data_reduce = bufferh;
__half* in_data = bufferh + batch_item_num;
int start = blockIdx.x * batch_item_num;
auto rnum = 1.0f / batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
in_data[i] = __float2half(__half2float(input1[idx]) + __half2float(input2[idx]) +
__half2float(input3[idx]));
in_data_reduce[i] = __float2half(__half2float(in_data[i]) * __half2float(rnum));
}
auto m = block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size); auto m = block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size);
m *= rnum;
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
in_data[i] = __float2half(__half2float(in_data[i]) - __half2float(m)); in_data[i] = __float2half(__half2float(in_data[i]) - __half2float(m));
in_data_reduce[i] = in_data_reduce[i] = __float2half(__half2float(in_data[i]) * __half2float(in_data[i]));
__float2half(__half2float(in_data[i]) * __half2float(in_data[i]) * __half2float(rnum));
} }
m = __float2half( m = block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size);
__half2float(block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size)) + m *= rnum;
1.0e-12f); m += 1.0e-12f;
auto r = __float2half(rsqrt(__half2float(m))); auto r = __float2half(rsqrt(__half2float(m)));
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
int idx = i + start; int idx = i + start;
output[idx] = __float2half(__half2float(in_data[i]) * __half2float(r)); out[idx] = __float2half(__half2float(in_data[i]) * __half2float(r));
} }
} }
template <class T>
__device__ T block_reduce(T* buffer, index_int batch_item_num, index_int tid, index_int block_size)
{
__syncthreads();
for(index_int s = block_size; s > 0; s >>= 1)
{
if(tid < s and tid + s < batch_item_num)
{
buffer[tid] = buffer[tid] + buffer[tid + s];
}
__syncthreads();
}
return buffer[0];
}
// m = x - mean(x) // m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12) // m / sqrt(mean(m ^ 2) + 1e-12)
template <class T> __global__ void triadd_layernorm_half(
__global__ void triadd_layernorm_kernel(
void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size) void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size)
{ {
T* input1 = reinterpret_cast<T*>(in1); __half* input1 = reinterpret_cast<__half*>(in1);
T* input2 = reinterpret_cast<T*>(in2); __half* input2 = reinterpret_cast<__half*>(in2);
T* input3 = reinterpret_cast<T*>(in3); __half* input3 = reinterpret_cast<__half*>(in3);
T* output = reinterpret_cast<T*>(data_out); __half* output = reinterpret_cast<__half*>(data_out);
extern MIGRAPHX_DEVICE_SHARED T buffer[]; float rnum = 1.0f / batch_item_num;
T* in_data_reduce = buffer; extern MIGRAPHX_DEVICE_SHARED __half bufferh[];
T* in_data = buffer + batch_item_num; __half* in_data_reduce = bufferh;
__half* in_data = bufferh + batch_item_num;
int start = blockIdx.x * batch_item_num; int start = blockIdx.x * batch_item_num;
auto rnum = 1.0f / batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
int idx = i + start; int idx = i + start;
in_data[i] = input1[idx] + input2[idx] + input3[idx]; in_data[i] = __float2half(__half2float(input1[idx]) + __half2float(input2[idx]) +
in_data_reduce[i] = in_data[i]; __half2float(input3[idx]));
// in_data_reduce[i] = __half2float(in_data[i]) * rnum;
}
auto m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size);
m = m * rnum;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = in_data[i] - m;
in_data_reduce[i] = in_data[i] * in_data[i];
// in_data_reduce[i] = __half2float(in_data[i] * in_data[i]) * rnum;
} }
m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size);
m = m * rnum + 1.0e-12f;
auto r = rsqrt(m);
for(int i = threadIdx.x; i < batch_item_num; i += block_size) layernorm_kernel_half(in_data, in_data_reduce, output, batch_item_num, block_size, rnum);
{
int idx = i + start;
// output[idx] = __half2float(in_data[i]) * r;
output[idx] = in_data[i] * r;
}
} }
void triadd_layernorm(hipStream_t stream, void triadd_layernorm(hipStream_t stream,
...@@ -397,7 +355,7 @@ void triadd_layernorm(hipStream_t stream, ...@@ -397,7 +355,7 @@ void triadd_layernorm(hipStream_t stream,
int block_num = in_s.elements() / batch_item_num; int block_num = in_s.elements() / batch_item_num;
int shared_size = batch_item_num * 2 * in_s.type_size(); int shared_size = batch_item_num * 2 * in_s.type_size();
half2_block_size = half2_block_size / 4; half2_block_size = half2_block_size / 4;
triadd_layernorm_kernel_half2<<<block_num, half2_block_size, shared_size, stream>>>( triadd_layernorm_half2<<<block_num, half2_block_size, shared_size, stream>>>(
arg1.data(), arg2.data(), arg3.data(), result.data(), batch_item_num, half2_block_size); arg1.data(), arg2.data(), arg3.data(), result.data(), batch_item_num, half2_block_size);
} }
else else
...@@ -409,11 +367,11 @@ void triadd_layernorm(hipStream_t stream, ...@@ -409,11 +367,11 @@ void triadd_layernorm(hipStream_t stream,
} }
__global__ void __global__ void
layernorm_kernel_half2(void* in1, void* data_out, index_int batch_item_num, index_int block_size) layernorm_half2(void* in1, void* data_out, index_int batch_item_num, index_int block_size)
{ {
__half2* input1 = reinterpret_cast<__half2*>(in1); __half2* input1 = reinterpret_cast<__half2*>(in1);
__half2* output = reinterpret_cast<__half2*>(data_out); __half2* output = reinterpret_cast<__half2*>(data_out);
auto rnum = __float2half2_rn(1.0f / batch_item_num); float rnum = 1.0f / batch_item_num;
batch_item_num /= 2; batch_item_num /= 2;
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[]; extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
__half2* in_data_reduce = buffer2; __half2* in_data_reduce = buffer2;
...@@ -427,28 +385,28 @@ layernorm_kernel_half2(void* in1, void* data_out, index_int batch_item_num, inde ...@@ -427,28 +385,28 @@ layernorm_kernel_half2(void* in1, void* data_out, index_int batch_item_num, inde
in_data_reduce[i] = in_data[i]; in_data_reduce[i] = in_data[i];
} }
auto m = layernorm_kernel_half2(in_data, in_data_reduce, output, batch_item_num, block_size, rnum);
block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{}); }
m = __hmul2(m, rnum);
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = __hsub2(in_data[i], m);
in_data_reduce[i] = __hmul2(in_data[i], in_data[i]);
}
m = block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
m = __hmul2(m, rnum);
auto eps = __float2half2_rn(1.0e-12f); __global__ void
auto r = __hadd2(m, eps); layernorm_half(void* in1, void* data_out, index_int batch_item_num, index_int block_size)
r = h2rsqrt(r); {
__half* input1 = reinterpret_cast<__half*>(in1);
__half* output = reinterpret_cast<__half*>(data_out);
float rnum = 1.0f / batch_item_num;
extern MIGRAPHX_DEVICE_SHARED __half buffer3[];
__half* in_data_reduce = buffer3;
__half* in_data = buffer3 + batch_item_num;
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
int idx = i + start; int idx = i + start;
output[idx] = __hmul2(in_data[i], r); in_data[i] = input1[idx];
in_data_reduce[i] = in_data[i];
} }
layernorm_kernel_half(in_data, in_data_reduce, output, batch_item_num, block_size, rnum);
} }
void layernorm(hipStream_t stream, const argument& result, const argument& arg1) void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
...@@ -462,7 +420,7 @@ void layernorm(hipStream_t stream, const argument& result, const argument& arg1) ...@@ -462,7 +420,7 @@ void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
int block_num = in_s.elements() / batch_item_num; int block_num = in_s.elements() / batch_item_num;
int shared_size = batch_item_num * 2 * in_s.type_size(); int shared_size = batch_item_num * 2 * in_s.type_size();
half2_block_size = half2_block_size / 4; half2_block_size = half2_block_size / 4;
layernorm_kernel_half2<<<block_num, half2_block_size, shared_size, stream>>>( layernorm_half2<<<block_num, half2_block_size, shared_size, stream>>>(
arg1.data(), result.data(), batch_item_num, half2_block_size); arg1.data(), result.data(), batch_item_num, half2_block_size);
} }
else else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment