Commit a6477298 authored by Shucai Xiao
Browse files

clang format

parent fe849702
...@@ -215,11 +215,14 @@ __device__ __half2 block_reduce_half2( ...@@ -215,11 +215,14 @@ __device__ __half2 block_reduce_half2(
// m = x - mean(x) // m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12) // m / sqrt(mean(m ^ 2) + 1e-12)
__device__ void layernorm_kernel_half2(__half2* in_data, __half2* in_data_reduce, __device__ void layernorm_kernel_half2(__half2* in_data,
__half2* out, index_int batch_item_num, index_int block_size, __half2* in_data_reduce,
__half2* out,
index_int batch_item_num,
index_int block_size,
float rbatch_num) float rbatch_num)
{ {
auto rnum = __float2half2_rn(rbatch_num); auto rnum = __float2half2_rn(rbatch_num);
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[]; extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
auto m = auto m =
block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{}); block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
...@@ -227,7 +230,7 @@ __device__ void layernorm_kernel_half2(__half2* in_data, __half2* in_data_reduce ...@@ -227,7 +230,7 @@ __device__ void layernorm_kernel_half2(__half2* in_data, __half2* in_data_reduce
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
in_data[i] = __hsub2(in_data[i], m); in_data[i] = __hsub2(in_data[i], m);
in_data_reduce[i] = __hmul2(in_data[i], in_data[i]); in_data_reduce[i] = __hmul2(in_data[i], in_data[i]);
} }
...@@ -241,7 +244,7 @@ __device__ void layernorm_kernel_half2(__half2* in_data, __half2* in_data_reduce ...@@ -241,7 +244,7 @@ __device__ void layernorm_kernel_half2(__half2* in_data, __half2* in_data_reduce
int start = blockIdx.x * batch_item_num; int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
int idx = i + start; int idx = i + start;
out[idx] = __hmul2(in_data[i], r); out[idx] = __hmul2(in_data[i], r);
} }
} }
...@@ -253,7 +256,7 @@ __global__ void triadd_layernorm_half2( ...@@ -253,7 +256,7 @@ __global__ void triadd_layernorm_half2(
__half2* input2 = reinterpret_cast<__half2*>(in2); __half2* input2 = reinterpret_cast<__half2*>(in2);
__half2* input3 = reinterpret_cast<__half2*>(in3); __half2* input3 = reinterpret_cast<__half2*>(in3);
__half2* output = reinterpret_cast<__half2*>(data_out); __half2* output = reinterpret_cast<__half2*>(data_out);
float rnum = 1.0f / batch_item_num; float rnum = 1.0f / batch_item_num;
batch_item_num /= 2; batch_item_num /= 2;
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[]; extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
__half2* in_data_reduce = buffer2; __half2* in_data_reduce = buffer2;
...@@ -289,15 +292,19 @@ block_reduce_half(T* buffer, index_int batch_item_num, index_int tid, index_int ...@@ -289,15 +292,19 @@ block_reduce_half(T* buffer, index_int batch_item_num, index_int tid, index_int
// m = x - mean(x) // m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12) // m / sqrt(mean(m ^ 2) + 1e-12)
__device__ void layernorm_kernel_half(__half* in_data, __half* in_data_reduce, __half* out, __device__ void layernorm_kernel_half(__half* in_data,
index_int batch_item_num, index_int block_size, float rnum) __half* in_data_reduce,
__half* out,
index_int batch_item_num,
index_int block_size,
float rnum)
{ {
auto m = block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size); auto m = block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size);
m *= rnum; m *= rnum;
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
in_data[i] = __float2half(__half2float(in_data[i]) - __half2float(m)); in_data[i] = __float2half(__half2float(in_data[i]) - __half2float(m));
in_data_reduce[i] = __float2half(__half2float(in_data[i]) * __half2float(in_data[i])); in_data_reduce[i] = __float2half(__half2float(in_data[i]) * __half2float(in_data[i]));
} }
...@@ -310,7 +317,7 @@ __device__ void layernorm_kernel_half(__half* in_data, __half* in_data_reduce, _ ...@@ -310,7 +317,7 @@ __device__ void layernorm_kernel_half(__half* in_data, __half* in_data_reduce, _
int start = blockIdx.x * batch_item_num; int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
int idx = i + start; int idx = i + start;
out[idx] = __float2half(__half2float(in_data[i]) * __half2float(r)); out[idx] = __float2half(__half2float(in_data[i]) * __half2float(r));
} }
} }
...@@ -324,7 +331,7 @@ __global__ void triadd_layernorm_half( ...@@ -324,7 +331,7 @@ __global__ void triadd_layernorm_half(
__half* input2 = reinterpret_cast<__half*>(in2); __half* input2 = reinterpret_cast<__half*>(in2);
__half* input3 = reinterpret_cast<__half*>(in3); __half* input3 = reinterpret_cast<__half*>(in3);
__half* output = reinterpret_cast<__half*>(data_out); __half* output = reinterpret_cast<__half*>(data_out);
float rnum = 1.0f / batch_item_num; float rnum = 1.0f / batch_item_num;
extern MIGRAPHX_DEVICE_SHARED __half bufferh[]; extern MIGRAPHX_DEVICE_SHARED __half bufferh[];
__half* in_data_reduce = bufferh; __half* in_data_reduce = bufferh;
__half* in_data = bufferh + batch_item_num; __half* in_data = bufferh + batch_item_num;
...@@ -332,8 +339,8 @@ __global__ void triadd_layernorm_half( ...@@ -332,8 +339,8 @@ __global__ void triadd_layernorm_half(
int start = blockIdx.x * batch_item_num; int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
int idx = i + start; int idx = i + start;
in_data[i] = __float2half(__half2float(input1[idx]) + __half2float(input2[idx]) + in_data[i] = __float2half(__half2float(input1[idx]) + __half2float(input2[idx]) +
__half2float(input3[idx])); __half2float(input3[idx]));
} }
...@@ -371,7 +378,7 @@ layernorm_half2(void* in1, void* data_out, index_int batch_item_num, index_int b ...@@ -371,7 +378,7 @@ layernorm_half2(void* in1, void* data_out, index_int batch_item_num, index_int b
{ {
__half2* input1 = reinterpret_cast<__half2*>(in1); __half2* input1 = reinterpret_cast<__half2*>(in1);
__half2* output = reinterpret_cast<__half2*>(data_out); __half2* output = reinterpret_cast<__half2*>(data_out);
float rnum = 1.0f / batch_item_num; float rnum = 1.0f / batch_item_num;
batch_item_num /= 2; batch_item_num /= 2;
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[]; extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
__half2* in_data_reduce = buffer2; __half2* in_data_reduce = buffer2;
...@@ -393,7 +400,7 @@ layernorm_half(void* in1, void* data_out, index_int batch_item_num, index_int bl ...@@ -393,7 +400,7 @@ layernorm_half(void* in1, void* data_out, index_int batch_item_num, index_int bl
{ {
__half* input1 = reinterpret_cast<__half*>(in1); __half* input1 = reinterpret_cast<__half*>(in1);
__half* output = reinterpret_cast<__half*>(data_out); __half* output = reinterpret_cast<__half*>(data_out);
float rnum = 1.0f / batch_item_num; float rnum = 1.0f / batch_item_num;
extern MIGRAPHX_DEVICE_SHARED __half buffer3[]; extern MIGRAPHX_DEVICE_SHARED __half buffer3[];
__half* in_data_reduce = buffer3; __half* in_data_reduce = buffer3;
__half* in_data = buffer3 + batch_item_num; __half* in_data = buffer3 + batch_item_num;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment