Commit 2d5e45b8 authored by Shucai Xiao

backup kernel refinement for add, mul, and mul_add

parent 16e5b5d0
...
@@ -8,27 +8,44 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-__global__ void add_kernel(__half* a, __half* b, __half* r, int n)
+static bool is_bert(const std::vector<shape>& ss)
+{
+    auto n_dim = ss.front().lens().size();
+    if(n_dim == 2)
+    {
+        auto stride = ss.at(1).strides();
+        return (stride[0] == 0);
+    }
+    return false;
+}
+
+__global__ void add_kernel(void* a, void* b, int n_dim, void* r, int n)
 {
+    __half2* ha = reinterpret_cast<__half2*>(a);
+    __half2* hb = reinterpret_cast<__half2*>(b);
+    __half2* hr = reinterpret_cast<__half2*>(r);
     int tid = blockIdx.x * blockDim.x + threadIdx.x;
     if(tid < n)
     {
-        r[tid] = a[tid] + b[tid % 768];
+        int idb  = tid % n_dim;
+        hr[tid] = __hadd2(ha[tid], hb[idb]);
     }
 }
 
 void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
 {
-    auto s2 = arg2.get_shape();
-    if(s2.element_space() == 768 and s2.type() == shape::half_type)
+    auto sr = result.get_shape();
+    std::vector<shape> ss;
+    ss.push_back(arg1.get_shape());
+    ss.push_back(arg2.get_shape());
+    if(sr.type() == shape::half_type and is_bert(ss))
     {
-        auto elem_num = s2.elements();
+        auto elem_num = sr.elements() / 2;
+        auto last_dim = sr.lens().back() / 2;
         int block_size = 1024;
         int block_num  = (elem_num + block_size - 1) / block_size;
-        add_kernel<<<block_num, block_size>>>(reinterpret_cast<__half*>(arg1.data()),
-                                              reinterpret_cast<__half*>(arg2.data()),
-                                              reinterpret_cast<__half*>(result.data()),
-                                              elem_num);
+        add_kernel<<<block_num, block_size>>>(
+            arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
     }
     else
     {
...
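The pattern in this diff — casting the `__half` buffers to `__half2`, halving the element count, and reusing `tid % last_dim` to walk the broadcast row — can be exercised outside MIGraphX. Below is a minimal standalone HIP sketch, not part of this commit, with made-up 4x768 test sizes, that runs the same packed-pair addition and spot-checks it against the scalar formula `r[i] = a[i] + b[i % cols]`:

```cpp
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdio>
#include <vector>

// Mirrors the committed kernel: each thread adds one packed __half2 pair;
// the second operand is a single broadcast row of n_dim pairs.
__global__ void add_kernel(void* a, void* b, int n_dim, void* r, int n)
{
    __half2* ha = reinterpret_cast<__half2*>(a);
    __half2* hb = reinterpret_cast<__half2*>(b);
    __half2* hr = reinterpret_cast<__half2*>(r);
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if(tid < n)
        hr[tid] = __hadd2(ha[tid], hb[tid % n_dim]);
}

int main()
{
    const int rows = 4, cols = 768; // assumed BERT-like sizes; cols must be even
    const int elems = rows * cols;
    std::vector<__half> a(elems), b(cols), r(elems);
    for(int i = 0; i < elems; ++i) a[i] = __float2half(float(i % 13));
    for(int j = 0; j < cols; ++j) b[j] = __float2half(float(j % 7));

    void *da, *db, *dr;
    hipMalloc(&da, elems * sizeof(__half));
    hipMalloc(&db, cols * sizeof(__half));
    hipMalloc(&dr, elems * sizeof(__half));
    hipMemcpy(da, a.data(), elems * sizeof(__half), hipMemcpyHostToDevice);
    hipMemcpy(db, b.data(), cols * sizeof(__half), hipMemcpyHostToDevice);

    int n        = elems / 2; // packed pairs, mirrors sr.elements() / 2
    int last_dim = cols / 2;  // packed row length, mirrors sr.lens().back() / 2
    int block_size = 1024;
    int block_num  = (n + block_size - 1) / block_size;
    add_kernel<<<block_num, block_size>>>(da, db, last_dim, dr, n);
    hipMemcpy(r.data(), dr, elems * sizeof(__half), hipMemcpyDeviceToHost);

    // spot-check one element against the scalar formula
    int i = cols + 5;
    std::printf("got %f, want %f\n",
                __half2float(r[i]),
                __half2float(a[i]) + __half2float(b[i % cols]));
    hipFree(da);
    hipFree(db);
    hipFree(dr);
    return 0;
}
```

The packing is only valid because `cols` is even, so no `__half2` pair straddles a row boundary; that is the same evenness assumption the `elements() / 2` and `lens().back() / 2` divisions in the diff rely on.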
...
@@ -8,44 +8,51 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-__global__ void mul_kernel(__half* a, __half* b, __half* r, int n)
+static bool is_bert(const std::vector<shape>& ss)
+{
+    auto n_dim = ss.front().lens().size();
+    if(n_dim == 2)
+    {
+        auto stride = ss.at(1).strides();
+        return (stride[0] == 0);
+    }
+    return false;
+}
+
+__global__ void mul_kernel(void* a, void* b, int n_dim, void* r, int n)
 {
+    __half2* ha = reinterpret_cast<__half2*>(a);
+    __half2* hb = reinterpret_cast<__half2*>(b);
+    __half2* hr = reinterpret_cast<__half2*>(r);
     int tid = blockIdx.x * blockDim.x + threadIdx.x;
     if(tid < n)
     {
-        r[tid] = a[tid] * b[tid % 768];
+        int idb  = tid % n_dim;
+        hr[tid] = __hmul2(ha[tid], hb[idb]);
     }
 }
 
 void mul(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
 {
-    auto s2 = arg2.get_shape();
-    if(s2.element_space() == 768 and s2.type() == shape::half_type)
+    auto sr = result.get_shape();
+    std::vector<shape> ss;
+    ss.push_back(arg1.get_shape());
+    ss.push_back(arg2.get_shape());
+    if(sr.type() == shape::half_type and is_bert(ss))
     {
-        auto elem_num = s2.elements();
+        auto elem_num = sr.elements() / 2;
+        auto last_dim = sr.lens().back() / 2;
         int block_size = 1024;
         int block_num  = (elem_num + block_size - 1) / block_size;
-        mul_kernel<<<block_num, block_size>>>(reinterpret_cast<__half*>(arg1.data()),
-                                              reinterpret_cast<__half*>(arg2.data()),
-                                              reinterpret_cast<__half*>(result.data()),
-                                              elem_num);
+        mul_kernel<<<block_num, block_size>>>(
+            arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
     }
     else
     {
         nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x * y; });
     }
 }
 
 void mul(hipStream_t stream,
          const argument& result,
          const argument& arg1,
          const argument& arg2,
          const argument& arg3)
 {
     nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z)
         __device__ { return x * y * z; });
 }
 } // namespace device
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
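The `is_bert` helper used by both the add and mul fast paths keys purely on strides: a rank-2 operation whose second input has stride 0 along dimension 0 is one row broadcast across every row. A hedged stand-in sketch follows; the `Shape` struct below only mimics `migraphx::shape`'s `lens()`/`strides()` accessors, not the real class, and the 384x768 layouts are assumed test values:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Minimal stand-in for migraphx::shape, just enough to run the check;
// the real class also carries the element type and more (assumption:
// only lens()/strides() matter to this helper).
struct Shape
{
    std::vector<std::size_t> lens_, strides_;
    const std::vector<std::size_t>& lens() const { return lens_; }
    const std::vector<std::size_t>& strides() const { return strides_; }
};

// Same logic as the committed helper: accept a rank-2 op whose second
// input has stride 0 along dim 0, i.e. one row broadcast to every row.
static bool is_bert(const std::vector<Shape>& ss)
{
    auto n_dim = ss.front().lens().size();
    if(n_dim == 2)
    {
        auto stride = ss.at(1).strides();
        return (stride[0] == 0);
    }
    return false;
}

int main()
{
    Shape act{{384, 768}, {768, 1}};   // sequence x hidden activations
    Shape bias{{384, 768}, {0, 1}};    // a {768} bias broadcast to every row
    Shape dense{{384, 768}, {768, 1}}; // a full, non-broadcast second operand
    std::printf("broadcast bias: %d\n", is_bert({act, bias}));  // 1 -> fast kernel
    std::printf("dense operand:  %d\n", is_bert({act, dense})); // 0 -> nary fallback
    return 0;
}
```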
...
@@ -11,84 +11,51 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-//__global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int n)
-//{
-//    int id = blockDim.x * blockIdx.x + threadIdx.x;
-//    __half* ha = reinterpret_cast<__half*>(a);
-//    __half* hb = reinterpret_cast<__half*>(b);
-//    __half* hx = reinterpret_cast<__half*>(x);
-//    __half* hr = reinterpret_cast<__half*>(r);
-//    if(id < n)
-//    {
-//        hr[id] = __float2half(__half2float(ha[id]) * __half2float(hx[id]) + __half2float(hb[id]));
-//    }
-//}
-// __global__ void mul_add_kernel(void* a, int an, void* x, int xn, void* b, int bn, void* r, int n)
-// {
-//     int id = blockDim.x * blockIdx.x + threadIdx.x;
-//     __half2* ha = reinterpret_cast<__half2*>(a);
-//     __half2* hb = reinterpret_cast<__half2*>(b);
-//     __half2* hx = reinterpret_cast<__half2*>(x);
-//     __half2* hr = reinterpret_cast<__half2*>(r);
-//     if(id < n)
-//     {
-//         hr[id] = __hadd2(__hmul2(ha[id % an], hx[id % xn]), hb[id % bn]);
-//     }
-// }
-__global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int* strides, int elem_num)
+__global__ void mul_add_kernel_dim3(void* a, void* x, void* b, int dim3, void* r, int n)
 {
-    __shared__ int shared_strides[18];
-    int tid = threadIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z;
-    if(tid < 18)
-    {
-        shared_strides[tid] = strides[tid];
-    }
-    __syncthreads();
-
+    int id = blockDim.x * blockIdx.x + threadIdx.x;
     __half2* ha = reinterpret_cast<__half2*>(a);
     __half2* hb = reinterpret_cast<__half2*>(b);
     __half2* hx = reinterpret_cast<__half2*>(x);
     __half2* hr = reinterpret_cast<__half2*>(r);
-
-    tid = tid + (blockIdx.x * (gridDim.y * gridDim.z) + blockIdx.y * gridDim.z + blockIdx.z) *
-                    blockDim.x * blockDim.y * blockDim.z;
-    if(tid < elem_num)
+    if(id < n)
     {
-        int tida = shared_strides[1] * blockIdx.x + shared_strides[2] * blockIdx.y +
-                   shared_strides[3] * blockIdx.z + shared_strides[4] * threadIdx.x +
-                   shared_strides[5] * threadIdx.y + threadIdx.z;
-        int tidx = shared_strides[7] * blockIdx.x + shared_strides[8] * blockIdx.y +
-                   shared_strides[9] * blockIdx.z + shared_strides[10] * threadIdx.x +
-                   shared_strides[11] * threadIdx.y + threadIdx.z;
-        int tidb = shared_strides[13] * blockIdx.x + shared_strides[14] * blockIdx.y +
-                   shared_strides[15] * blockIdx.z + shared_strides[16] * threadIdx.x +
-                   shared_strides[17] * threadIdx.y + threadIdx.z;
-        hr[tid] = __hadd2(__hmul2(ha[tida], hx[tidx]), hb[tidb]);
+        auto id1 = id % dim3;
+        hr[id]   = __hadd2(__hmul2(ha[id], hx[id1]), hb[id1]);
     }
 }
-// void mul_add(hipStream_t stream,
-//              const argument& result,
-//              const argument& arg1,
-//              const argument& arg2,
-//              const argument& arg3)
-// {
-//     auto type = result.get_shape().type();
-//     if(type == shape::half_type)
-//     {
-//         std::cout << "case1" << std::endl;
-//         mul_add_kernel<<<block_num, block_size>>>(
-//             arg1.data(), s1e, arg2.data(), s2e, arg3.data(), s3e, result.data(), elem_num);
-//     }
-//     else
-//     {
-//         std::cout << "mul_add" << std::endl;
-//         nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b)
-//             __device__ { return a * x + b; });
-//     }
-// }
+
+__global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int dim4, void* r, int n)
+{
+    int id = blockDim.x * blockIdx.x + threadIdx.x;
+    __half2* ha = reinterpret_cast<__half2*>(a);
+    __half2* hb = reinterpret_cast<__half2*>(b);
+    __half2* hx = reinterpret_cast<__half2*>(x);
+    __half2* hr = reinterpret_cast<__half2*>(r);
+    if(id < n)
+    {
+        int idb = id / factor + id % dim4;
+        hr[id]  = __hadd2(__hmul2(ha[id], hx[id]), hb[idb]);
+    }
+}
+
+static bool is_bert(const std::vector<shape>& ss)
+{
+    auto n_dim = ss.front().lens().size();
+    if(n_dim == 3)
+    {
+        auto stride = ss.at(2).strides();
+        return (stride[1] == 0);
+    }
+    else if(n_dim == 2)
+    {
+        auto stride1 = ss.at(1).strides();
+        auto stride2 = ss.at(2).strides();
+        return (stride1 == stride2 and stride1[0] == 0);
+    }
+    return false;
+}
 
 void mul_add(hipStream_t stream,
              const argument& result,
@@ -97,27 +64,29 @@ void mul_add(hipStream_t stream,
              const argument& arg3)
 {
     auto sr = result.get_shape();
-    auto s1 = arg1.get_shape();
-    auto s2 = arg2.get_shape();
-    auto s3 = arg3.get_shape();
     auto type = sr.type();
-    if(type == sr.type())
+    std::vector<shape> ss;
+    ss.push_back(arg1.get_shape());
+    ss.push_back(arg2.get_shape());
+    ss.push_back(arg3.get_shape());
+    if(type == shape::half_type and is_bert(ss))
     {
-        hip_visit_all(result, arg1, arg2, arg3, sr, s1, s2, s3)(
-            [&](auto r, auto i1, auto i2, auto i3, auto dsr, auto ds1, auto ds2, auto ds3) {
-                __half2* rp  = reinterpret_cast<__half2*>(r.data());
-                __half2* i1p = reinterpret_cast<__half2*>(i1.data());
-                __half2* i2p = reinterpret_cast<__half2*>(i2.data());
-                __half2* i3p = reinterpret_cast<__half2*>(i3.data());
-                gs_launch(stream, sr.elements() / 2)([=](auto i) __device__ {
-                    auto idx  = dsr.multi(i);
-                    auto idx1 = ds1.index(idx);
-                    auto idx2 = ds2.index(idx);
-                    auto idx3 = ds3.index(idx);
-                    rp[i] = __hadd2(__hmul2(i1p[idx1], i2p[idx2]), i3p[idx3]);
-                });
-            });
+        auto elem_num  = sr.elements() / 2;
+        auto lens      = sr.lens();
+        int last_dim   = lens.back() / 2;
+        auto n_dim     = lens.size();
+        int block_size = 1024;
+        int block_num  = (elem_num + block_size - 1) / block_size;
+        if(n_dim == 2)
+        {
+            mul_add_kernel_dim3<<<block_num, block_size>>>(
+                arg1.data(), arg2.data(), arg3.data(), last_dim, result.data(), elem_num);
+        }
+        else
+        {
+            int factor = lens[1];
+            mul_add_kernel_dim4<<<block_num, block_size>>>(
+                arg1.data(), arg2.data(), arg3.data(), factor, last_dim, result.data(), elem_num);
+        }
     }
     else
     {
...
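For the rank-2 fast path, the `id % dim3` index is exactly what the removed `hip_visit_all` code would have computed by decomposing the flat id into a multi-index and dotting it with the broadcast operand's strides `{0, 1}`. A small CPU-only sketch, with assumed toy sizes and not taken from the commit, checks that equivalence exhaustively:

```cpp
#include <cassert>
#include <cstddef>

int main()
{
    // 768 halves per row pack into 384 __half2 pairs; cols2 plays the role
    // of the packed last_dim passed to mul_add_kernel_dim3 (assumed sizes).
    const std::size_t rows = 384, cols2 = 384;
    for(std::size_t id = 0; id < rows * cols2; ++id)
    {
        // general path: flat id -> multi-index (i0, i1), then dot with the
        // broadcast operand's strides {0, 1} (row stride is zero)
        std::size_t i0 = id / cols2, i1 = id % cols2;
        std::size_t general = 0 * i0 + 1 * i1;
        // specialized fast path used by the kernel
        std::size_t fast = id % cols2;
        assert(general == fast);
    }
    return 0;
}
```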