Commit 580673a0 authored by Shucai Xiao

clang format

parent 80a6ca93
@@ -258,7 +258,7 @@ __global__ void triadd_layernorm_kernel_half2(
     for(int i = threadIdx.x; i < batch_item_num; i += block_size)
     {
         in_data[i] = __hsub2(in_data[i], m);
         // in_data_reduce[i] = __hmul2(__hmul2(in_data[i], in_data[i]), rnum);
         in_data_reduce[i] = __hmul2(in_data[i], in_data[i]);
     }
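For context on the packed-half path above: __hsub2 and __hmul2 operate on both halves of a __half2 at once, so each loop iteration centers and squares two elements. A minimal sketch of that step, assuming a __half2 buffer, a packed mean m, and HIP's fp16 intrinsics; the kernel name and wrapper are illustrative, not part of the repository:

#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>

// Sketch only: center each packed pair around the mean and square it,
// mirroring the half2 loop in the hunk above. All names are illustrative.
__global__ void center_and_square_half2(__half2* in_data,
                                        __half2* in_data_reduce,
                                        __half2 m,
                                        int batch_item_num)
{
    int block_size = blockDim.x;
    for(int i = threadIdx.x; i < batch_item_num; i += block_size)
    {
        in_data[i]        = __hsub2(in_data[i], m);          // (x - mean) for two lanes
        in_data_reduce[i] = __hmul2(in_data[i], in_data[i]); // (x - mean)^2, reduced later
    }
}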
@@ -383,7 +383,8 @@ __global__ void triadd_layernorm_kernel(
         in_data_reduce[i] = __half2float(in_data[i] * in_data[i]) * rnum;
     }
-    m = __half2float(block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size)) + 1.0e-12f;
+    m = __half2float(block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size)) +
+        1.0e-12f;
     auto r = rsqrt(m);
     for(int i = threadIdx.x; i < batch_item_num; i += block_size)
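The hunk above is the tail of the layernorm: in_data_reduce accumulates (x - mean)^2 scaled by rnum (assumed to be 1/batch_item_num), block_reduce sums it into the variance, a 1.0e-12f epsilon guards against division by zero, and rsqrt of that sum gives the factor applied to every element. A scalar host-side reference of that epilogue, assuming the mean has already been subtracted; all names here are illustrative:

#include <cmath>

// Sketch only: CPU reference for the normalization epilogue shown above.
// x[] is assumed to already have the row mean subtracted; n is batch_item_num.
void layernorm_epilogue(float* x, int n)
{
    const float rnum = 1.0f / n;
    float variance   = 0.0f;
    for(int i = 0; i < n; ++i)
        variance += x[i] * x[i] * rnum; // same accumulation as in_data_reduce
    const float r = 1.0f / std::sqrt(variance + 1.0e-12f); // rsqrt(var + eps)
    for(int i = 0; i < n; ++i)
        x[i] *= r; // scale to unit variance
}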
@@ -409,7 +410,8 @@ void triadd_layernorm(hipStream_t stream,
     // int shared_size = batch_item_num * 2 * in_s.type_size();
     // half2_block_size = half2_block_size / 4;
     // triadd_layernorm_kernel_half2<<<block_num, half2_block_size, shared_size, stream>>>(
-    //     arg1.data(), arg2.data(), arg3.data(), result.data(), batch_item_num, half2_block_size);
+    //     arg1.data(), arg2.data(), arg3.data(), result.data(), batch_item_num,
+    //     half2_block_size);
     // }
     // if(type == shape::half_type and (batch_item_num % 2) == 0)
     if(type == shape::half_type)
@@ -418,13 +420,13 @@ void triadd_layernorm(hipStream_t stream,
         int block_num = in_s.elements() / batch_item_num;
         int shared_size = batch_item_num * 2 * in_s.type_size();
         reduce_block_size = reduce_block_size / 2;
-        triadd_layernorm_kernel<__half><<<block_num, reduce_block_size, shared_size, stream>>>(
-            arg1.data(),
-            arg2.data(),
-            arg3.data(),
-            result.data(),
-            batch_item_num,
-            reduce_block_size);
+        triadd_layernorm_kernel<__half>
+            <<<block_num, reduce_block_size, shared_size, stream>>>(arg1.data(),
+                                                                    arg2.data(),
+                                                                    arg3.data(),
+                                                                    result.data(),
+                                                                    batch_item_num,
+                                                                    reduce_block_size);
     }
     else
     {
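The last hunk only rewraps the kernel launch, but the configuration it launches is worth spelling out: one block per row being normalized, and dynamic shared memory sized for two per-row staging buffers (in_data and in_data_reduce). A hedged sketch of that setup with a stand-in kernel so it compiles on its own; the stub, the starting block size of 512, and the wrapper name are assumptions, not the repository's code:

#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>

// Stand-in for the repository's triadd_layernorm_kernel; sketch only.
template <class T>
__global__ void layernorm_kernel_stub(const T*, const T*, const T*, T*, int, int)
{
}

void launch_example(hipStream_t stream,
                    const __half* arg1,
                    const __half* arg2,
                    const __half* arg3,
                    __half* result,
                    int elements,       // in_s.elements()
                    int batch_item_num, // length of each row to normalize
                    int type_size)      // in_s.type_size()
{
    int block_num         = elements / batch_item_num;      // one block per row
    int shared_size       = batch_item_num * 2 * type_size; // two staging buffers per row
    int reduce_block_size = 512;                             // assumed starting value
    reduce_block_size     = reduce_block_size / 2;           // halved as in the diff
    layernorm_kernel_stub<__half>
        <<<block_num, reduce_block_size, shared_size, stream>>>(
            arg1, arg2, arg3, result, batch_item_num, reduce_block_size);
}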