Commit 562724bf authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 83f89182
...@@ -41,7 +41,7 @@ __global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int* strides, ...@@ -41,7 +41,7 @@ __global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int* strides,
{ {
__shared__ int shared_strides[18]; __shared__ int shared_strides[18];
int tid = threadIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z; int tid = threadIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z;
if (tid < 18) if(tid < 18)
{ {
shared_strides[tid] = strides[tid]; shared_strides[tid] = strides[tid];
} }
...@@ -52,12 +52,19 @@ __global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int* strides, ...@@ -52,12 +52,19 @@ __global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int* strides,
__half2* hx = reinterpret_cast<__half2*>(x); __half2* hx = reinterpret_cast<__half2*>(x);
__half2* hr = reinterpret_cast<__half2*>(r); __half2* hr = reinterpret_cast<__half2*>(r);
tid = tid + (blockIdx.x * (gridDim.y * gridDim.z) + blockIdx.y * gridDim.z + blockIdx.z) * blockDim.x * blockDim.y * blockDim.z; tid = tid + (blockIdx.x * (gridDim.y * gridDim.z) + blockIdx.y * gridDim.z + blockIdx.z) *
blockDim.x * blockDim.y * blockDim.z;
if(tid < elem_num) if(tid < elem_num)
{ {
int tida = shared_strides[1] * blockIdx.x + shared_strides[2] * blockIdx.y + shared_strides[3] * blockIdx.z + shared_strides[4] * threadIdx.x + shared_strides[5] * threadIdx.y + threadIdx.z; int tida = shared_strides[1] * blockIdx.x + shared_strides[2] * blockIdx.y +
int tidx = shared_strides[7] * blockIdx.x + shared_strides[8] * blockIdx.y + shared_strides[9] * blockIdx.z + shared_strides[10] * threadIdx.x + shared_strides[11] * threadIdx.y + threadIdx.z; shared_strides[3] * blockIdx.z + shared_strides[4] * threadIdx.x +
int tidb = shared_strides[13] * blockIdx.x + shared_strides[14] * blockIdx.y + shared_strides[15] * blockIdx.z + shared_strides[16] * threadIdx.x + shared_strides[17] * threadIdx.y + threadIdx.z; shared_strides[5] * threadIdx.y + threadIdx.z;
int tidx = shared_strides[7] * blockIdx.x + shared_strides[8] * blockIdx.y +
shared_strides[9] * blockIdx.z + shared_strides[10] * threadIdx.x +
shared_strides[11] * threadIdx.y + threadIdx.z;
int tidb = shared_strides[13] * blockIdx.x + shared_strides[14] * blockIdx.y +
shared_strides[15] * blockIdx.z + shared_strides[16] * threadIdx.x +
shared_strides[17] * threadIdx.y + threadIdx.z;
hr[tid] = __hadd2(__hmul2(ha[tida], hx[tidx]), hb[tidb]); hr[tid] = __hadd2(__hmul2(ha[tida], hx[tidx]), hb[tidb]);
} }
} }
...@@ -89,28 +96,28 @@ void mul_add(hipStream_t stream, ...@@ -89,28 +96,28 @@ void mul_add(hipStream_t stream,
const argument& arg2, const argument& arg2,
const argument& arg3) const argument& arg3)
{ {
auto sr = result.get_shape(); auto sr = result.get_shape();
auto s1 = arg1.get_shape(); auto s1 = arg1.get_shape();
auto s2 = arg2.get_shape(); auto s2 = arg2.get_shape();
auto s3 = arg3.get_shape(); auto s3 = arg3.get_shape();
auto type = sr.type(); auto type = sr.type();
if(type == sr.type()) if(type == sr.type())
{ {
hip_visit_all(result, arg1, arg2, arg3, sr, s1, s2, s3)([&](auto r, auto i1, auto i2, auto i3, hip_visit_all(result, arg1, arg2, arg3, sr, s1, s2, s3)(
auto dsr, auto ds1, auto ds2, auto ds3) { [&](auto r, auto i1, auto i2, auto i3, auto dsr, auto ds1, auto ds2, auto ds3) {
__half2* rp = reinterpret_cast<__half2*>(r.data()); __half2* rp = reinterpret_cast<__half2*>(r.data());
__half2* i1p = reinterpret_cast<__half2*>(i1.data()); __half2* i1p = reinterpret_cast<__half2*>(i1.data());
__half2* i2p = reinterpret_cast<__half2*>(i2.data()); __half2* i2p = reinterpret_cast<__half2*>(i2.data());
__half2* i3p = reinterpret_cast<__half2*>(i3.data()); __half2* i3p = reinterpret_cast<__half2*>(i3.data());
gs_launch(stream, sr.elements() / 2)([=](auto i) __device__ { gs_launch(stream, sr.elements() / 2)([=](auto i) __device__ {
auto idx = dsr.multi(i); auto idx = dsr.multi(i);
auto idx1 = ds1.index(idx); auto idx1 = ds1.index(idx);
auto idx2 = ds2.index(idx); auto idx2 = ds2.index(idx);
auto idx3 = ds3.index(idx); auto idx3 = ds3.index(idx);
rp[i] = __hadd2(__hmul2(i1p[idx1], i2p[idx2]), i3p[idx3]); rp[i] = __hadd2(__hmul2(i1p[idx1], i2p[idx2]), i3p[idx3]);
});
}); });
});
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment