Commit 6c834296 authored by Shucai Xiao
Browse files

use fma for the mul_add and refine add_gelu implementation

parent 9e5c56da
...@@ -57,11 +57,15 @@ __global__ void add_gelu_kernel(void* a, void* b, int n_dim, void* r, int n) ...@@ -57,11 +57,15 @@ __global__ void add_gelu_kernel(void* a, void* b, int n_dim, void* r, int n)
int idb = tid % n_dim; int idb = tid % n_dim;
auto sum = __hadd2(ha[tid], hb[idb]); auto sum = __hadd2(ha[tid], hb[idb]);
__half2 sqrt2 = __float2half2_rn(M_SQRT1_2); __half2 sqrt2 = __float2half2_rn(M_SQRT1_2);
sum = __hmul2(sum, sqrt2); auto x = __hmul2(sum, sqrt2);
auto f2 = __half22float2(sum); auto f2 = __half22float2(x);
f2 += 1.0f; f2.x = ::erf(f2.x);
f2.y = ::erf(f2.y);
auto h2 = __floats2half2_rn(f2.x, f2.y); auto h2 = __floats2half2_rn(f2.x, f2.y);
auto one = __float2half2_rn(1.0f);
h2 = __hadd2(h2, one);
__half2 point5 = __float2half2_rn(0.5f); __half2 point5 = __float2half2_rn(0.5f);
hr[tid] = __hmul2(sum, __hmul2(point5, h2)); hr[tid] = __hmul2(sum, __hmul2(point5, h2));
} }
...@@ -83,7 +87,7 @@ void add_gelu(hipStream_t stream, ...@@ -83,7 +87,7 @@ void add_gelu(hipStream_t stream,
auto last_dim = sr.lens().back() / 2; auto last_dim = sr.lens().back() / 2;
int block_size = 1024; int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size; int block_num = (elem_num + block_size - 1) / block_size;
add_gelu_kernel<<<block_num, block_size>>>( add_gelu_kernel<<<block_num, block_size, 0, stream>>>(
arg1.data(), arg2.data(), last_dim, result.data(), elem_num); arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
} }
else else
......
...@@ -35,7 +35,7 @@ __global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int d ...@@ -35,7 +35,7 @@ __global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int d
if(id < n) if(id < n)
{ {
int idb = id / (factor * dim4) * dim4 + id % dim4; int idb = id / (factor * dim4) * dim4 + id % dim4;
hr[id] = __hadd2(__hmul2(ha[id], hx[id]), hb[idb]); hr[id] = __hfma2(ha[id], hx[id], hb[idb]);
} }
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment