"profiler/vscode:/vscode.git/clone" did not exist on "5683ea4ed8f9f9e852dd92bdecbbf4cbd7cf45e5"
Commit fe849702 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

refactor of the layernorm code

parent 76547728
......@@ -215,28 +215,12 @@ __device__ __half2 block_reduce_half2(
// m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12)
__global__ void triadd_layernorm_kernel_half2(
void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size)
__device__ void layernorm_kernel_half2(__half2* in_data, __half2* in_data_reduce,
__half2* out, index_int batch_item_num, index_int block_size,
float rbatch_num)
{
__half2* input1 = reinterpret_cast<__half2*>(in1);
__half2* input2 = reinterpret_cast<__half2*>(in2);
__half2* input3 = reinterpret_cast<__half2*>(in3);
__half2* output = reinterpret_cast<__half2*>(data_out);
auto rnum = __float2half2_rn(1.0f / batch_item_num);
batch_item_num /= 2;
auto rnum = __float2half2_rn(rbatch_num);
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
__half2* in_data_reduce = buffer2;
__half2* in_data = buffer2 + batch_item_num;
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
in_data[i] = __hadd2(__hadd2(input1[idx], input2[idx]), input3[idx]);
in_data_reduce[i] = in_data[i];
// in_data_reduce[i] = __hmul2(in_data[i], rnum);
}
auto m =
block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
m = __hmul2(m, rnum);
......@@ -244,7 +228,6 @@ __global__ void triadd_layernorm_kernel_half2(
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = __hsub2(in_data[i], m);
// in_data_reduce[i] = __hmul2(__hmul2(in_data[i], in_data[i]), rnum);
in_data_reduce[i] = __hmul2(in_data[i], in_data[i]);
}
......@@ -255,11 +238,36 @@ __global__ void triadd_layernorm_kernel_half2(
auto r = __hadd2(m, eps);
r = h2rsqrt(r);
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
output[idx] = __hmul2(in_data[i], r);
out[idx] = __hmul2(in_data[i], r);
}
}
__global__ void triadd_layernorm_half2(
void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size)
{
__half2* input1 = reinterpret_cast<__half2*>(in1);
__half2* input2 = reinterpret_cast<__half2*>(in2);
__half2* input3 = reinterpret_cast<__half2*>(in3);
__half2* output = reinterpret_cast<__half2*>(data_out);
float rnum = 1.0f / batch_item_num;
batch_item_num /= 2;
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
__half2* in_data_reduce = buffer2;
__half2* in_data = buffer2 + batch_item_num;
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
in_data[i] = __hadd2(__hadd2(input1[idx], input2[idx]), input3[idx]);
in_data_reduce[i] = in_data[i];
}
layernorm_kernel_half2(in_data, in_data_reduce, output, batch_item_num, block_size, rnum);
}
template <class T>
......@@ -281,105 +289,55 @@ block_reduce_half(T* buffer, index_int batch_item_num, index_int tid, index_int
// m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12)
__global__ void triadd_layernorm_kernel_half(
void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size)
__device__ void layernorm_kernel_half(__half* in_data, __half* in_data_reduce, __half* out,
index_int batch_item_num, index_int block_size, float rnum)
{
__half* input1 = reinterpret_cast<__half*>(in1);
__half* input2 = reinterpret_cast<__half*>(in2);
__half* input3 = reinterpret_cast<__half*>(in3);
__half* output = reinterpret_cast<__half*>(data_out);
extern MIGRAPHX_DEVICE_SHARED __half bufferh[];
__half* in_data_reduce = bufferh;
__half* in_data = bufferh + batch_item_num;
int start = blockIdx.x * batch_item_num;
auto rnum = 1.0f / batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
in_data[i] = __float2half(__half2float(input1[idx]) + __half2float(input2[idx]) +
__half2float(input3[idx]));
in_data_reduce[i] = __float2half(__half2float(in_data[i]) * __half2float(rnum));
}
auto m = block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size);
m *= rnum;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = __float2half(__half2float(in_data[i]) - __half2float(m));
in_data_reduce[i] =
__float2half(__half2float(in_data[i]) * __half2float(in_data[i]) * __half2float(rnum));
in_data_reduce[i] = __float2half(__half2float(in_data[i]) * __half2float(in_data[i]));
}
m = __float2half(
__half2float(block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size)) +
1.0e-12f);
m = block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size);
m *= rnum;
m += 1.0e-12f;
auto r = __float2half(rsqrt(__half2float(m)));
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
output[idx] = __float2half(__half2float(in_data[i]) * __half2float(r));
out[idx] = __float2half(__half2float(in_data[i]) * __half2float(r));
}
}
template <class T>
__device__ T block_reduce(T* buffer, index_int batch_item_num, index_int tid, index_int block_size)
{
__syncthreads();
for(index_int s = block_size; s > 0; s >>= 1)
{
if(tid < s and tid + s < batch_item_num)
{
buffer[tid] = buffer[tid] + buffer[tid + s];
}
__syncthreads();
}
return buffer[0];
}
// m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12)
template <class T>
__global__ void triadd_layernorm_kernel(
__global__ void triadd_layernorm_half(
void* in1, void* in2, void* in3, void* data_out, index_int batch_item_num, index_int block_size)
{
T* input1 = reinterpret_cast<T*>(in1);
T* input2 = reinterpret_cast<T*>(in2);
T* input3 = reinterpret_cast<T*>(in3);
T* output = reinterpret_cast<T*>(data_out);
extern MIGRAPHX_DEVICE_SHARED T buffer[];
T* in_data_reduce = buffer;
T* in_data = buffer + batch_item_num;
__half* input1 = reinterpret_cast<__half*>(in1);
__half* input2 = reinterpret_cast<__half*>(in2);
__half* input3 = reinterpret_cast<__half*>(in3);
__half* output = reinterpret_cast<__half*>(data_out);
float rnum = 1.0f / batch_item_num;
extern MIGRAPHX_DEVICE_SHARED __half bufferh[];
__half* in_data_reduce = bufferh;
__half* in_data = bufferh + batch_item_num;
int start = blockIdx.x * batch_item_num;
auto rnum = 1.0f / batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
in_data[i] = input1[idx] + input2[idx] + input3[idx];
in_data_reduce[i] = in_data[i];
// in_data_reduce[i] = __half2float(in_data[i]) * rnum;
}
auto m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size);
m = m * rnum;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = in_data[i] - m;
in_data_reduce[i] = in_data[i] * in_data[i];
// in_data_reduce[i] = __half2float(in_data[i] * in_data[i]) * rnum;
in_data[i] = __float2half(__half2float(input1[idx]) + __half2float(input2[idx]) +
__half2float(input3[idx]));
}
m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size);
m = m * rnum + 1.0e-12f;
auto r = rsqrt(m);
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
// output[idx] = __half2float(in_data[i]) * r;
output[idx] = in_data[i] * r;
}
layernorm_kernel_half(in_data, in_data_reduce, output, batch_item_num, block_size, rnum);
}
void triadd_layernorm(hipStream_t stream,
......@@ -397,7 +355,7 @@ void triadd_layernorm(hipStream_t stream,
int block_num = in_s.elements() / batch_item_num;
int shared_size = batch_item_num * 2 * in_s.type_size();
half2_block_size = half2_block_size / 4;
triadd_layernorm_kernel_half2<<<block_num, half2_block_size, shared_size, stream>>>(
triadd_layernorm_half2<<<block_num, half2_block_size, shared_size, stream>>>(
arg1.data(), arg2.data(), arg3.data(), result.data(), batch_item_num, half2_block_size);
}
else
......@@ -409,11 +367,11 @@ void triadd_layernorm(hipStream_t stream,
}
__global__ void
layernorm_kernel_half2(void* in1, void* data_out, index_int batch_item_num, index_int block_size)
layernorm_half2(void* in1, void* data_out, index_int batch_item_num, index_int block_size)
{
__half2* input1 = reinterpret_cast<__half2*>(in1);
__half2* output = reinterpret_cast<__half2*>(data_out);
auto rnum = __float2half2_rn(1.0f / batch_item_num);
float rnum = 1.0f / batch_item_num;
batch_item_num /= 2;
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
__half2* in_data_reduce = buffer2;
......@@ -427,28 +385,28 @@ layernorm_kernel_half2(void* in1, void* data_out, index_int batch_item_num, inde
in_data_reduce[i] = in_data[i];
}
auto m =
block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
m = __hmul2(m, rnum);
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = __hsub2(in_data[i], m);
in_data_reduce[i] = __hmul2(in_data[i], in_data[i]);
}
m = block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
m = __hmul2(m, rnum);
layernorm_kernel_half2(in_data, in_data_reduce, output, batch_item_num, block_size, rnum);
}
auto eps = __float2half2_rn(1.0e-12f);
auto r = __hadd2(m, eps);
r = h2rsqrt(r);
__global__ void
layernorm_half(void* in1, void* data_out, index_int batch_item_num, index_int block_size)
{
__half* input1 = reinterpret_cast<__half*>(in1);
__half* output = reinterpret_cast<__half*>(data_out);
float rnum = 1.0f / batch_item_num;
extern MIGRAPHX_DEVICE_SHARED __half buffer3[];
__half* in_data_reduce = buffer3;
__half* in_data = buffer3 + batch_item_num;
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
output[idx] = __hmul2(in_data[i], r);
int idx = i + start;
in_data[i] = input1[idx];
in_data_reduce[i] = in_data[i];
}
layernorm_kernel_half(in_data, in_data_reduce, output, batch_item_num, block_size, rnum);
}
void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
......@@ -462,7 +420,7 @@ void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
int block_num = in_s.elements() / batch_item_num;
int shared_size = batch_item_num * 2 * in_s.type_size();
half2_block_size = half2_block_size / 4;
layernorm_kernel_half2<<<block_num, half2_block_size, shared_size, stream>>>(
layernorm_half2<<<block_num, half2_block_size, shared_size, stream>>>(
arg1.data(), result.data(), batch_item_num, half2_block_size);
}
else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment