"vscode:/vscode.git/clone" did not exist on "1ff6c73e07eb83ac0e4017b60589c52475b4157e"
Commit 562724bf authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 83f89182
......@@ -41,7 +41,7 @@ __global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int* strides,
{
__shared__ int shared_strides[18];
int tid = threadIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z;
if (tid < 18)
if(tid < 18)
{
shared_strides[tid] = strides[tid];
}
......@@ -52,12 +52,19 @@ __global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int* strides,
__half2* hx = reinterpret_cast<__half2*>(x);
__half2* hr = reinterpret_cast<__half2*>(r);
tid = tid + (blockIdx.x * (gridDim.y * gridDim.z) + blockIdx.y * gridDim.z + blockIdx.z) * blockDim.x * blockDim.y * blockDim.z;
tid = tid + (blockIdx.x * (gridDim.y * gridDim.z) + blockIdx.y * gridDim.z + blockIdx.z) *
blockDim.x * blockDim.y * blockDim.z;
if(tid < elem_num)
{
int tida = shared_strides[1] * blockIdx.x + shared_strides[2] * blockIdx.y + shared_strides[3] * blockIdx.z + shared_strides[4] * threadIdx.x + shared_strides[5] * threadIdx.y + threadIdx.z;
int tidx = shared_strides[7] * blockIdx.x + shared_strides[8] * blockIdx.y + shared_strides[9] * blockIdx.z + shared_strides[10] * threadIdx.x + shared_strides[11] * threadIdx.y + threadIdx.z;
int tidb = shared_strides[13] * blockIdx.x + shared_strides[14] * blockIdx.y + shared_strides[15] * blockIdx.z + shared_strides[16] * threadIdx.x + shared_strides[17] * threadIdx.y + threadIdx.z;
int tida = shared_strides[1] * blockIdx.x + shared_strides[2] * blockIdx.y +
shared_strides[3] * blockIdx.z + shared_strides[4] * threadIdx.x +
shared_strides[5] * threadIdx.y + threadIdx.z;
int tidx = shared_strides[7] * blockIdx.x + shared_strides[8] * blockIdx.y +
shared_strides[9] * blockIdx.z + shared_strides[10] * threadIdx.x +
shared_strides[11] * threadIdx.y + threadIdx.z;
int tidb = shared_strides[13] * blockIdx.x + shared_strides[14] * blockIdx.y +
shared_strides[15] * blockIdx.z + shared_strides[16] * threadIdx.x +
shared_strides[17] * threadIdx.y + threadIdx.z;
hr[tid] = __hadd2(__hmul2(ha[tida], hx[tidx]), hb[tidb]);
}
}
......@@ -97,8 +104,8 @@ void mul_add(hipStream_t stream,
if(type == sr.type())
{
hip_visit_all(result, arg1, arg2, arg3, sr, s1, s2, s3)([&](auto r, auto i1, auto i2, auto i3,
auto dsr, auto ds1, auto ds2, auto ds3) {
hip_visit_all(result, arg1, arg2, arg3, sr, s1, s2, s3)(
[&](auto r, auto i1, auto i2, auto i3, auto dsr, auto ds1, auto ds2, auto ds3) {
__half2* rp = reinterpret_cast<__half2*>(r.data());
__half2* i1p = reinterpret_cast<__half2*>(i1.data());
__half2* i2p = reinterpret_cast<__half2*>(i2.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment