Commit 5459e4d8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 2d5e45b8
......@@ -25,7 +25,7 @@ __global__ void add_kernel(void* a, void* b, int n_dim, void* r, int n)
__half2* ha = reinterpret_cast<__half2*>(a);
__half2* hb = reinterpret_cast<__half2*>(b);
__half2* hr = reinterpret_cast<__half2*>(r);
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n)
{
int idb = tid % n_dim;
......@@ -42,10 +42,11 @@ void add(hipStream_t stream, const argument& result, const argument& arg1, const
if(sr.type() == shape::half_type and is_bert(ss))
{
auto elem_num = sr.elements() / 2;
auto last_dim = sr.lens().back() / 2;
auto last_dim = sr.lens().back() / 2;
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
add_kernel<<<block_num, block_size>>>(arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
add_kernel<<<block_num, block_size>>>(
arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
}
else
{
......
......@@ -25,7 +25,7 @@ __global__ void mul_kernel(void* a, void* b, int n_dim, void* r, int n)
__half2* ha = reinterpret_cast<__half2*>(a);
__half2* hb = reinterpret_cast<__half2*>(b);
__half2* hr = reinterpret_cast<__half2*>(r);
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n)
{
int idb = tid % n_dim;
......@@ -42,10 +42,11 @@ void mul(hipStream_t stream, const argument& result, const argument& arg1, const
if(sr.type() == shape::half_type and is_bert(ss))
{
auto elem_num = sr.elements() / 2;
auto last_dim = sr.lens().back() / 2;
auto last_dim = sr.lens().back() / 2;
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
mul_kernel<<<block_num, block_size>>>(arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
mul_kernel<<<block_num, block_size>>>(
arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
}
else
{
......
......@@ -21,7 +21,7 @@ __global__ void mul_add_kernel_dim3(void* a, void* x, void* b, int dim3, void* r
if(id < n)
{
auto id1 = id % dim3;
hr[id] = __hadd2(__hmul2(ha[id], hx[id1]), hb[id1]);
hr[id] = __hadd2(__hmul2(ha[id], hx[id1]), hb[id1]);
}
}
......@@ -35,7 +35,7 @@ __global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int d
if(id < n)
{
int idb = id / factor + id % dim4;
hr[id] = __hadd2(__hmul2(ha[id], hx[id]), hb[idb]);
hr[id] = __hadd2(__hmul2(ha[id], hx[id]), hb[idb]);
}
}
......@@ -72,20 +72,22 @@ void mul_add(hipStream_t stream,
ss.push_back(arg3.get_shape());
if(type == shape::half_type and is_bert(ss))
{
auto elem_num = sr.elements() / 2;
auto lens = sr.lens();
int last_dim = lens.back() / 2;
auto n_dim = lens.size();
auto elem_num = sr.elements() / 2;
auto lens = sr.lens();
int last_dim = lens.back() / 2;
auto n_dim = lens.size();
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
if (n_dim == 2)
int block_num = (elem_num + block_size - 1) / block_size;
if(n_dim == 2)
{
mul_add_kernel_dim3<<<block_num, block_size>>>(arg1.data(), arg2.data(), arg3.data(), last_dim, result.data(), elem_num);
mul_add_kernel_dim3<<<block_num, block_size>>>(
arg1.data(), arg2.data(), arg3.data(), last_dim, result.data(), elem_num);
}
else
{
int factor = lens[1];
mul_add_kernel_dim4<<<block_num, block_size>>>(arg1.data(), arg2.data(), arg3.data(), factor, last_dim, result.data(), elem_num);
mul_add_kernel_dim4<<<block_num, block_size>>>(
arg1.data(), arg2.data(), arg3.data(), factor, last_dim, result.data(), elem_num);
}
}
else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment