Commit a6477298 authored by Shucai Xiao
Browse files

clang format

parent fe849702
......@@ -215,11 +215,14 @@ __device__ __half2 block_reduce_half2(
// m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12)
__device__ void layernorm_kernel_half2(__half2* in_data, __half2* in_data_reduce,
__half2* out, index_int batch_item_num, index_int block_size,
__device__ void layernorm_kernel_half2(__half2* in_data,
__half2* in_data_reduce,
__half2* out,
index_int batch_item_num,
index_int block_size,
float rbatch_num)
{
auto rnum = __float2half2_rn(rbatch_num);
auto rnum = __float2half2_rn(rbatch_num);
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
auto m =
block_reduce_half2(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
......@@ -227,7 +230,7 @@ __device__ void layernorm_kernel_half2(__half2* in_data, __half2* in_data_reduce
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = __hsub2(in_data[i], m);
in_data[i] = __hsub2(in_data[i], m);
in_data_reduce[i] = __hmul2(in_data[i], in_data[i]);
}
......@@ -241,7 +244,7 @@ __device__ void layernorm_kernel_half2(__half2* in_data, __half2* in_data_reduce
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
int idx = i + start;
out[idx] = __hmul2(in_data[i], r);
}
}
......@@ -253,7 +256,7 @@ __global__ void triadd_layernorm_half2(
__half2* input2 = reinterpret_cast<__half2*>(in2);
__half2* input3 = reinterpret_cast<__half2*>(in3);
__half2* output = reinterpret_cast<__half2*>(data_out);
float rnum = 1.0f / batch_item_num;
float rnum = 1.0f / batch_item_num;
batch_item_num /= 2;
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
__half2* in_data_reduce = buffer2;
......@@ -289,15 +292,19 @@ block_reduce_half(T* buffer, index_int batch_item_num, index_int tid, index_int
// m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12)
__device__ void layernorm_kernel_half(__half* in_data, __half* in_data_reduce, __half* out,
index_int batch_item_num, index_int block_size, float rnum)
__device__ void layernorm_kernel_half(__half* in_data,
__half* in_data_reduce,
__half* out,
index_int batch_item_num,
index_int block_size,
float rnum)
{
auto m = block_reduce_half(in_data_reduce, batch_item_num, threadIdx.x, block_size);
m *= rnum;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = __float2half(__half2float(in_data[i]) - __half2float(m));
in_data[i] = __float2half(__half2float(in_data[i]) - __half2float(m));
in_data_reduce[i] = __float2half(__half2float(in_data[i]) * __half2float(in_data[i]));
}
......@@ -310,7 +317,7 @@ __device__ void layernorm_kernel_half(__half* in_data, __half* in_data_reduce, _
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
int idx = i + start;
out[idx] = __float2half(__half2float(in_data[i]) * __half2float(r));
}
}
......@@ -324,7 +331,7 @@ __global__ void triadd_layernorm_half(
__half* input2 = reinterpret_cast<__half*>(in2);
__half* input3 = reinterpret_cast<__half*>(in3);
__half* output = reinterpret_cast<__half*>(data_out);
float rnum = 1.0f / batch_item_num;
float rnum = 1.0f / batch_item_num;
extern MIGRAPHX_DEVICE_SHARED __half bufferh[];
__half* in_data_reduce = bufferh;
__half* in_data = bufferh + batch_item_num;
......@@ -332,8 +339,8 @@ __global__ void triadd_layernorm_half(
int start = blockIdx.x * batch_item_num;
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
int idx = i + start;
in_data[i] = __float2half(__half2float(input1[idx]) + __half2float(input2[idx]) +
int idx = i + start;
in_data[i] = __float2half(__half2float(input1[idx]) + __half2float(input2[idx]) +
__half2float(input3[idx]));
}
......@@ -371,7 +378,7 @@ layernorm_half2(void* in1, void* data_out, index_int batch_item_num, index_int b
{
__half2* input1 = reinterpret_cast<__half2*>(in1);
__half2* output = reinterpret_cast<__half2*>(data_out);
float rnum = 1.0f / batch_item_num;
float rnum = 1.0f / batch_item_num;
batch_item_num /= 2;
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
__half2* in_data_reduce = buffer2;
......@@ -393,7 +400,7 @@ layernorm_half(void* in1, void* data_out, index_int batch_item_num, index_int bl
{
__half* input1 = reinterpret_cast<__half*>(in1);
__half* output = reinterpret_cast<__half*>(data_out);
float rnum = 1.0f / batch_item_num;
float rnum = 1.0f / batch_item_num;
extern MIGRAPHX_DEVICE_SHARED __half buffer3[];
__half* in_data_reduce = buffer3;
__half* in_data = buffer3 + batch_item_num;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment