Commit 5459e4d8 authored by Shucai Xiao
Browse files

clang format

parent 2d5e45b8
...@@ -25,7 +25,7 @@ __global__ void add_kernel(void* a, void* b, int n_dim, void* r, int n) ...@@ -25,7 +25,7 @@ __global__ void add_kernel(void* a, void* b, int n_dim, void* r, int n)
__half2* ha = reinterpret_cast<__half2*>(a); __half2* ha = reinterpret_cast<__half2*>(a);
__half2* hb = reinterpret_cast<__half2*>(b); __half2* hb = reinterpret_cast<__half2*>(b);
__half2* hr = reinterpret_cast<__half2*>(r); __half2* hr = reinterpret_cast<__half2*>(r);
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n) if(tid < n)
{ {
int idb = tid % n_dim; int idb = tid % n_dim;
...@@ -42,10 +42,11 @@ void add(hipStream_t stream, const argument& result, const argument& arg1, const ...@@ -42,10 +42,11 @@ void add(hipStream_t stream, const argument& result, const argument& arg1, const
if(sr.type() == shape::half_type and is_bert(ss)) if(sr.type() == shape::half_type and is_bert(ss))
{ {
auto elem_num = sr.elements() / 2; auto elem_num = sr.elements() / 2;
auto last_dim = sr.lens().back() / 2; auto last_dim = sr.lens().back() / 2;
int block_size = 1024; int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size; int block_num = (elem_num + block_size - 1) / block_size;
add_kernel<<<block_num, block_size>>>(arg1.data(), arg2.data(), last_dim, result.data(), elem_num); add_kernel<<<block_num, block_size>>>(
arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
} }
else else
{ {
......
...@@ -25,7 +25,7 @@ __global__ void mul_kernel(void* a, void* b, int n_dim, void* r, int n) ...@@ -25,7 +25,7 @@ __global__ void mul_kernel(void* a, void* b, int n_dim, void* r, int n)
__half2* ha = reinterpret_cast<__half2*>(a); __half2* ha = reinterpret_cast<__half2*>(a);
__half2* hb = reinterpret_cast<__half2*>(b); __half2* hb = reinterpret_cast<__half2*>(b);
__half2* hr = reinterpret_cast<__half2*>(r); __half2* hr = reinterpret_cast<__half2*>(r);
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n) if(tid < n)
{ {
int idb = tid % n_dim; int idb = tid % n_dim;
...@@ -42,10 +42,11 @@ void mul(hipStream_t stream, const argument& result, const argument& arg1, const ...@@ -42,10 +42,11 @@ void mul(hipStream_t stream, const argument& result, const argument& arg1, const
if(sr.type() == shape::half_type and is_bert(ss)) if(sr.type() == shape::half_type and is_bert(ss))
{ {
auto elem_num = sr.elements() / 2; auto elem_num = sr.elements() / 2;
auto last_dim = sr.lens().back() / 2; auto last_dim = sr.lens().back() / 2;
int block_size = 1024; int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size; int block_num = (elem_num + block_size - 1) / block_size;
mul_kernel<<<block_num, block_size>>>(arg1.data(), arg2.data(), last_dim, result.data(), elem_num); mul_kernel<<<block_num, block_size>>>(
arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
} }
else else
{ {
......
...@@ -21,7 +21,7 @@ __global__ void mul_add_kernel_dim3(void* a, void* x, void* b, int dim3, void* r ...@@ -21,7 +21,7 @@ __global__ void mul_add_kernel_dim3(void* a, void* x, void* b, int dim3, void* r
if(id < n) if(id < n)
{ {
auto id1 = id % dim3; auto id1 = id % dim3;
hr[id] = __hadd2(__hmul2(ha[id], hx[id1]), hb[id1]); hr[id] = __hadd2(__hmul2(ha[id], hx[id1]), hb[id1]);
} }
} }
...@@ -35,7 +35,7 @@ __global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int d ...@@ -35,7 +35,7 @@ __global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int d
if(id < n) if(id < n)
{ {
int idb = id / factor + id % dim4; int idb = id / factor + id % dim4;
hr[id] = __hadd2(__hmul2(ha[id], hx[id]), hb[idb]); hr[id] = __hadd2(__hmul2(ha[id], hx[id]), hb[idb]);
} }
} }
...@@ -72,20 +72,22 @@ void mul_add(hipStream_t stream, ...@@ -72,20 +72,22 @@ void mul_add(hipStream_t stream,
ss.push_back(arg3.get_shape()); ss.push_back(arg3.get_shape());
if(type == shape::half_type and is_bert(ss)) if(type == shape::half_type and is_bert(ss))
{ {
auto elem_num = sr.elements() / 2; auto elem_num = sr.elements() / 2;
auto lens = sr.lens(); auto lens = sr.lens();
int last_dim = lens.back() / 2; int last_dim = lens.back() / 2;
auto n_dim = lens.size(); auto n_dim = lens.size();
int block_size = 1024; int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size; int block_num = (elem_num + block_size - 1) / block_size;
if (n_dim == 2) if(n_dim == 2)
{ {
mul_add_kernel_dim3<<<block_num, block_size>>>(arg1.data(), arg2.data(), arg3.data(), last_dim, result.data(), elem_num); mul_add_kernel_dim3<<<block_num, block_size>>>(
arg1.data(), arg2.data(), arg3.data(), last_dim, result.data(), elem_num);
} }
else else
{ {
int factor = lens[1]; int factor = lens[1];
mul_add_kernel_dim4<<<block_num, block_size>>>(arg1.data(), arg2.data(), arg3.data(), factor, last_dim, result.data(), elem_num); mul_add_kernel_dim4<<<block_num, block_size>>>(
arg1.data(), arg2.data(), arg3.data(), factor, last_dim, result.data(), elem_num);
} }
} }
else else
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment