Commit 57914709 authored by Shucai Xiao
Browse files

clang format

parent 7da8748c
......@@ -54,26 +54,25 @@ __global__ void add_gelu_kernel(void* a, void* b, int n_dim, void* r, int n)
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n)
{
int idb = tid % n_dim;
auto sum = __hadd2(ha[tid], hb[idb]);
int idb = tid % n_dim;
auto sum = __hadd2(ha[tid], hb[idb]);
__half2 sqrt2 = __float2half2_rn(M_SQRT1_2);
sum = __hmul2(sum, sqrt2);
auto f2 = __half22float2(sum);
sum = __hmul2(sum, sqrt2);
auto f2 = __half22float2(sum);
f2 += 1.0f;
auto h2 = __floats2half2_rn(f2.x, f2.y);
__half2 point5 = __float2half2_rn(0.5f);
hr[tid] = __hmul2(sum, __hmul2(point5, h2));
hr[tid] = __hmul2(sum, __hmul2(point5, h2));
}
}
void add_gelu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2)
{
auto sr = result.get_shape();
auto sr = result.get_shape();
auto type = sr.type();
std::vector<shape> ss;
ss.push_back(arg1.get_shape());
......@@ -84,7 +83,8 @@ void add_gelu(hipStream_t stream,
auto last_dim = sr.lens().back() / 2;
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
add_gelu_kernel<<<block_num, block_size>>>(arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
add_gelu_kernel<<<block_num, block_size>>>(
arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
}
else
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment