"docs/vscode:/vscode.git/clone" did not exist on "8cd50e3f28de941e69dd5b62754cc2475eb48d03"
Commit 57914709 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 7da8748c
...@@ -54,26 +54,25 @@ __global__ void add_gelu_kernel(void* a, void* b, int n_dim, void* r, int n) ...@@ -54,26 +54,25 @@ __global__ void add_gelu_kernel(void* a, void* b, int n_dim, void* r, int n)
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n) if(tid < n)
{ {
int idb = tid % n_dim; int idb = tid % n_dim;
auto sum = __hadd2(ha[tid], hb[idb]); auto sum = __hadd2(ha[tid], hb[idb]);
__half2 sqrt2 = __float2half2_rn(M_SQRT1_2); __half2 sqrt2 = __float2half2_rn(M_SQRT1_2);
sum = __hmul2(sum, sqrt2); sum = __hmul2(sum, sqrt2);
auto f2 = __half22float2(sum); auto f2 = __half22float2(sum);
f2 += 1.0f; f2 += 1.0f;
auto h2 = __floats2half2_rn(f2.x, f2.y); auto h2 = __floats2half2_rn(f2.x, f2.y);
__half2 point5 = __float2half2_rn(0.5f); __half2 point5 = __float2half2_rn(0.5f);
hr[tid] = __hmul2(sum, __hmul2(point5, h2)); hr[tid] = __hmul2(sum, __hmul2(point5, h2));
} }
} }
void add_gelu(hipStream_t stream, void add_gelu(hipStream_t stream,
const argument& result, const argument& result,
const argument& arg1, const argument& arg1,
const argument& arg2) const argument& arg2)
{ {
auto sr = result.get_shape(); auto sr = result.get_shape();
auto type = sr.type(); auto type = sr.type();
std::vector<shape> ss; std::vector<shape> ss;
ss.push_back(arg1.get_shape()); ss.push_back(arg1.get_shape());
...@@ -84,7 +83,8 @@ void add_gelu(hipStream_t stream, ...@@ -84,7 +83,8 @@ void add_gelu(hipStream_t stream,
auto last_dim = sr.lens().back() / 2; auto last_dim = sr.lens().back() / 2;
int block_size = 1024; int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size; int block_num = (elem_num + block_size - 1) / block_size;
add_gelu_kernel<<<block_num, block_size>>>(arg1.data(), arg2.data(), last_dim, result.data(), elem_num); add_gelu_kernel<<<block_num, block_size>>>(
arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment