Commit 562724bf authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 83f89182
......@@ -41,7 +41,7 @@ __global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int* strides,
{
__shared__ int shared_strides[18];
int tid = threadIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z;
if (tid < 18)
if(tid < 18)
{
shared_strides[tid] = strides[tid];
}
......@@ -52,12 +52,19 @@ __global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int* strides,
__half2* hx = reinterpret_cast<__half2*>(x);
__half2* hr = reinterpret_cast<__half2*>(r);
tid = tid + (blockIdx.x * (gridDim.y * gridDim.z) + blockIdx.y * gridDim.z + blockIdx.z) * blockDim.x * blockDim.y * blockDim.z;
tid = tid + (blockIdx.x * (gridDim.y * gridDim.z) + blockIdx.y * gridDim.z + blockIdx.z) *
blockDim.x * blockDim.y * blockDim.z;
if(tid < elem_num)
{
int tida = shared_strides[1] * blockIdx.x + shared_strides[2] * blockIdx.y + shared_strides[3] * blockIdx.z + shared_strides[4] * threadIdx.x + shared_strides[5] * threadIdx.y + threadIdx.z;
int tidx = shared_strides[7] * blockIdx.x + shared_strides[8] * blockIdx.y + shared_strides[9] * blockIdx.z + shared_strides[10] * threadIdx.x + shared_strides[11] * threadIdx.y + threadIdx.z;
int tidb = shared_strides[13] * blockIdx.x + shared_strides[14] * blockIdx.y + shared_strides[15] * blockIdx.z + shared_strides[16] * threadIdx.x + shared_strides[17] * threadIdx.y + threadIdx.z;
int tida = shared_strides[1] * blockIdx.x + shared_strides[2] * blockIdx.y +
shared_strides[3] * blockIdx.z + shared_strides[4] * threadIdx.x +
shared_strides[5] * threadIdx.y + threadIdx.z;
int tidx = shared_strides[7] * blockIdx.x + shared_strides[8] * blockIdx.y +
shared_strides[9] * blockIdx.z + shared_strides[10] * threadIdx.x +
shared_strides[11] * threadIdx.y + threadIdx.z;
int tidb = shared_strides[13] * blockIdx.x + shared_strides[14] * blockIdx.y +
shared_strides[15] * blockIdx.z + shared_strides[16] * threadIdx.x +
shared_strides[17] * threadIdx.y + threadIdx.z;
hr[tid] = __hadd2(__hmul2(ha[tida], hx[tidx]), hb[tidb]);
}
}
......@@ -89,28 +96,28 @@ void mul_add(hipStream_t stream,
const argument& arg2,
const argument& arg3)
{
auto sr = result.get_shape();
auto s1 = arg1.get_shape();
auto s2 = arg2.get_shape();
auto s3 = arg3.get_shape();
auto sr = result.get_shape();
auto s1 = arg1.get_shape();
auto s2 = arg2.get_shape();
auto s3 = arg3.get_shape();
auto type = sr.type();
if(type == sr.type())
{
hip_visit_all(result, arg1, arg2, arg3, sr, s1, s2, s3)([&](auto r, auto i1, auto i2, auto i3,
auto dsr, auto ds1, auto ds2, auto ds3) {
__half2* rp = reinterpret_cast<__half2*>(r.data());
__half2* i1p = reinterpret_cast<__half2*>(i1.data());
__half2* i2p = reinterpret_cast<__half2*>(i2.data());
__half2* i3p = reinterpret_cast<__half2*>(i3.data());
gs_launch(stream, sr.elements() / 2)([=](auto i) __device__ {
auto idx = dsr.multi(i);
auto idx1 = ds1.index(idx);
auto idx2 = ds2.index(idx);
auto idx3 = ds3.index(idx);
rp[i] = __hadd2(__hmul2(i1p[idx1], i2p[idx2]), i3p[idx3]);
hip_visit_all(result, arg1, arg2, arg3, sr, s1, s2, s3)(
[&](auto r, auto i1, auto i2, auto i3, auto dsr, auto ds1, auto ds2, auto ds3) {
__half2* rp = reinterpret_cast<__half2*>(r.data());
__half2* i1p = reinterpret_cast<__half2*>(i1.data());
__half2* i2p = reinterpret_cast<__half2*>(i2.data());
__half2* i3p = reinterpret_cast<__half2*>(i3.data());
gs_launch(stream, sr.elements() / 2)([=](auto i) __device__ {
auto idx = dsr.multi(i);
auto idx1 = ds1.index(idx);
auto idx2 = ds2.index(idx);
auto idx3 = ds3.index(idx);
rp[i] = __hadd2(__hmul2(i1p[idx1], i2p[idx2]), i3p[idx3]);
});
});
});
}
else
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment