"vscode:/vscode.git/clone" did not exist on "b436213ec35fda21fd6260b2ce8b5d212c5072fc"
Commit 780fffc8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent fc48a1d3
...@@ -378,20 +378,20 @@ __global__ void triadd_layernorm_kernel( ...@@ -378,20 +378,20 @@ __global__ void triadd_layernorm_kernel(
} }
auto m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size); auto m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size);
m = m * rnum; m = m * rnum;
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
in_data[i] = in_data[i] - m; in_data[i] = in_data[i] - m;
in_data_reduce[i] = in_data[i] * in_data[i]; in_data_reduce[i] = in_data[i] * in_data[i];
// in_data_reduce[i] = __half2float(in_data[i] * in_data[i]) * rnum; // in_data_reduce[i] = __half2float(in_data[i] * in_data[i]) * rnum;
} }
m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size); m = block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size);
m = m * rnum + 1.0e-12f; m = m * rnum + 1.0e-12f;
auto r = rsqrt(m); auto r = rsqrt(m);
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
int idx = i + start; int idx = i + start;
// output[idx] = __half2float(in_data[i]) * r; // output[idx] = __half2float(in_data[i]) * r;
output[idx] = in_data[i] * r; output[idx] = in_data[i] * r;
} }
...@@ -423,8 +423,8 @@ void triadd_layernorm(hipStream_t stream, ...@@ -423,8 +423,8 @@ void triadd_layernorm(hipStream_t stream,
} }
} }
__global__ void layernorm_kernel_half2( __global__ void
void* in1, void* data_out, index_int batch_item_num, index_int block_size) layernorm_kernel_half2(void* in1, void* data_out, index_int batch_item_num, index_int block_size)
{ {
__half2* input1 = reinterpret_cast<__half2*>(in1); __half2* input1 = reinterpret_cast<__half2*>(in1);
__half2* output = reinterpret_cast<__half2*>(data_out); __half2* output = reinterpret_cast<__half2*>(data_out);
...@@ -448,7 +448,7 @@ __global__ void layernorm_kernel_half2( ...@@ -448,7 +448,7 @@ __global__ void layernorm_kernel_half2(
for(int i = threadIdx.x; i < batch_item_num; i += block_size) for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{ {
in_data[i] = __hsub2(in_data[i], m); in_data[i] = __hsub2(in_data[i], m);
in_data_reduce[i] = __hmul2(in_data[i], in_data[i]); in_data_reduce[i] = __hmul2(in_data[i], in_data[i]);
} }
...@@ -483,7 +483,7 @@ void layernorm(hipStream_t stream, const argument& result, const argument& arg1) ...@@ -483,7 +483,7 @@ void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
else else
{ {
layernorm_fusion(stream, result, arg1)([](auto x) { return x; }, layernorm_fusion(stream, result, arg1)([](auto x) { return x; },
[](auto x, auto& y, auto) { y = x; }); [](auto x, auto& y, auto) { y = x; });
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment