Commit 2d5e45b8 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

backup kernel refinement for add, mul, and mul_add

parent 16e5b5d0
......@@ -8,27 +8,44 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
__global__ void add_kernel(__half* a, __half* b, __half* r, int n)
static bool is_bert(const std::vector<shape>& ss)
{
auto n_dim = ss.front().lens().size();
if(n_dim == 2)
{
auto stride = ss.at(1).strides();
return (stride[0] == 0);
}
return false;
}
__global__ void add_kernel(void* a, void* b, int n_dim, void* r, int n)
{
__half2* ha = reinterpret_cast<__half2*>(a);
__half2* hb = reinterpret_cast<__half2*>(b);
__half2* hr = reinterpret_cast<__half2*>(r);
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n)
{
r[tid] = a[tid] + b[tid % 768];
int idb = tid % n_dim;
hr[tid] = __hadd2(ha[tid], hb[idb]);
}
}
void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
auto s2 = arg2.get_shape();
if(s2.element_space() == 768 and s2.type() == shape::half_type)
auto sr = result.get_shape();
std::vector<shape> ss;
ss.push_back(arg1.get_shape());
ss.push_back(arg2.get_shape());
if(sr.type() == shape::half_type and is_bert(ss))
{
auto elem_num = s2.elements();
auto elem_num = sr.elements() / 2;
auto last_dim = sr.lens().back() / 2;
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
add_kernel<<<block_num, block_size>>>(reinterpret_cast<__half*>(arg1.data()),
reinterpret_cast<__half*>(arg2.data()),
reinterpret_cast<__half*>(result.data()),
elem_num);
add_kernel<<<block_num, block_size>>>(arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
}
else
{
......
......@@ -8,44 +8,51 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
__global__ void mul_kernel(__half* a, __half* b, __half* r, int n)
static bool is_bert(const std::vector<shape>& ss)
{
auto n_dim = ss.front().lens().size();
if(n_dim == 2)
{
auto stride = ss.at(1).strides();
return (stride[0] == 0);
}
return false;
}
__global__ void mul_kernel(void* a, void* b, int n_dim, void* r, int n)
{
__half2* ha = reinterpret_cast<__half2*>(a);
__half2* hb = reinterpret_cast<__half2*>(b);
__half2* hr = reinterpret_cast<__half2*>(r);
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid < n)
{
r[tid] = a[tid] * b[tid % 768];
int idb = tid % n_dim;
hr[tid] = __hmul2(ha[tid], hb[idb]);
}
}
void mul(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
auto s2 = arg2.get_shape();
if(s2.element_space() == 768 and s2.type() == shape::half_type)
auto sr = result.get_shape();
std::vector<shape> ss;
ss.push_back(arg1.get_shape());
ss.push_back(arg2.get_shape());
if(sr.type() == shape::half_type and is_bert(ss))
{
auto elem_num = s2.elements();
auto elem_num = sr.elements() / 2;
auto last_dim = sr.lens().back() / 2;
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
mul_kernel<<<block_num, block_size>>>(reinterpret_cast<__half*>(arg1.data()),
reinterpret_cast<__half*>(arg2.data()),
reinterpret_cast<__half*>(result.data()),
elem_num);
mul_kernel<<<block_num, block_size>>>(arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
}
else
{
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x * y; });
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x + y; });
}
}
void mul(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3)
{
nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z)
__device__ { return x * y * z; });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -11,84 +11,51 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
//__global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int n)
//{
// int id = blockDim.x * blockIdx.x + threadIdx.x;
// __half* ha = reinterpret_cast<__half*>(a);
// __half* hb = reinterpret_cast<__half*>(b);
// __half* hx = reinterpret_cast<__half*>(x);
// __half* hr = reinterpret_cast<__half*>(r);
// if (id < n)
// {
// hr[id] = __float2half(__half2float(ha[id]) * __half2float(hx[id]) + __half2float(hb[id]));
// }
//}
// __global__ void mul_add_kernel(void* a, int an, void* x, int xn, void* b, int bn, void* r, int n)
// {
// int id = blockDim.x * blockIdx.x + threadIdx.x;
// __half2* ha = reinterpret_cast<__half2*>(a);
// __half2* hb = reinterpret_cast<__half2*>(b);
// __half2* hx = reinterpret_cast<__half2*>(x);
// __half2* hr = reinterpret_cast<__half2*>(r);
// if(id < n)
// {
// hr[id] = __hadd2(__hmul2(ha[id % an], hx[id % xn]), hb[id % bn]);
// }
// }
__global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int* strides, int elem_num)
__global__ void mul_add_kernel_dim3(void* a, void* x, void* b, int dim3, void* r, int n)
{
__shared__ int shared_strides[18];
int tid = threadIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z;
if(tid < 18)
int id = blockDim.x * blockIdx.x + threadIdx.x;
__half2* ha = reinterpret_cast<__half2*>(a);
__half2* hb = reinterpret_cast<__half2*>(b);
__half2* hx = reinterpret_cast<__half2*>(x);
__half2* hr = reinterpret_cast<__half2*>(r);
if(id < n)
{
shared_strides[tid] = strides[tid];
auto id1 = id % dim3;
hr[id] = __hadd2(__hmul2(ha[id], hx[id1]), hb[id1]);
}
__syncthreads();
}
__global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int dim4, void* r, int n)
{
int id = blockDim.x * blockIdx.x + threadIdx.x;
__half2* ha = reinterpret_cast<__half2*>(a);
__half2* hb = reinterpret_cast<__half2*>(b);
__half2* hx = reinterpret_cast<__half2*>(x);
__half2* hr = reinterpret_cast<__half2*>(r);
tid = tid + (blockIdx.x * (gridDim.y * gridDim.z) + blockIdx.y * gridDim.z + blockIdx.z) *
blockDim.x * blockDim.y * blockDim.z;
if(tid < elem_num)
if(id < n)
{
int tida = shared_strides[1] * blockIdx.x + shared_strides[2] * blockIdx.y +
shared_strides[3] * blockIdx.z + shared_strides[4] * threadIdx.x +
shared_strides[5] * threadIdx.y + threadIdx.z;
int tidx = shared_strides[7] * blockIdx.x + shared_strides[8] * blockIdx.y +
shared_strides[9] * blockIdx.z + shared_strides[10] * threadIdx.x +
shared_strides[11] * threadIdx.y + threadIdx.z;
int tidb = shared_strides[13] * blockIdx.x + shared_strides[14] * blockIdx.y +
shared_strides[15] * blockIdx.z + shared_strides[16] * threadIdx.x +
shared_strides[17] * threadIdx.y + threadIdx.z;
hr[tid] = __hadd2(__hmul2(ha[tida], hx[tidx]), hb[tidb]);
int idb = id / factor + id % dim4;
hr[id] = __hadd2(__hmul2(ha[id], hx[id]), hb[idb]);
}
}
// void mul_add(hipStream_t stream,
// const argument& result,
// const argument& arg1,
// const argument& arg2,
// const argument& arg3)
// {
// auto type = result.get_shape().type();
// if(type == shape::half_type)
// {
// std::cout << "case1" << std::endl;
// mul_add_kernel<<<block_num, block_size>>>(
// arg1.data(), s1e, arg2.data(), s2e, arg3.data(), s3e, result.data(), elem_num);
// }
// else
// {
// std::cout << "mul_add" << std::endl;
// nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b)
// __device__ { return a * x + b; });
// }
// }
static bool is_bert(const std::vector<shape>& ss)
{
auto n_dim = ss.front().lens().size();
if(n_dim == 3)
{
auto stride = ss.at(2).strides();
return (stride[1] == 0);
}
else if(n_dim == 2)
{
auto stride1 = ss.at(1).strides();
auto stride2 = ss.at(2).strides();
return (stride1 == stride2 and stride1[0] == 0);
}
return false;
}
void mul_add(hipStream_t stream,
const argument& result,
......@@ -97,27 +64,29 @@ void mul_add(hipStream_t stream,
const argument& arg3)
{
auto sr = result.get_shape();
auto s1 = arg1.get_shape();
auto s2 = arg2.get_shape();
auto s3 = arg3.get_shape();
auto type = sr.type();
if(type == sr.type())
std::vector<shape> ss;
ss.push_back(arg1.get_shape());
ss.push_back(arg2.get_shape());
ss.push_back(arg3.get_shape());
if(type == shape::half_type and is_bert(ss))
{
hip_visit_all(result, arg1, arg2, arg3, sr, s1, s2, s3)(
[&](auto r, auto i1, auto i2, auto i3, auto dsr, auto ds1, auto ds2, auto ds3) {
__half2* rp = reinterpret_cast<__half2*>(r.data());
__half2* i1p = reinterpret_cast<__half2*>(i1.data());
__half2* i2p = reinterpret_cast<__half2*>(i2.data());
__half2* i3p = reinterpret_cast<__half2*>(i3.data());
gs_launch(stream, sr.elements() / 2)([=](auto i) __device__ {
auto idx = dsr.multi(i);
auto idx1 = ds1.index(idx);
auto idx2 = ds2.index(idx);
auto idx3 = ds3.index(idx);
rp[i] = __hadd2(__hmul2(i1p[idx1], i2p[idx2]), i3p[idx3]);
});
});
auto elem_num = sr.elements() / 2;
auto lens = sr.lens();
int last_dim = lens.back() / 2;
auto n_dim = lens.size();
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
if (n_dim == 2)
{
mul_add_kernel_dim3<<<block_num, block_size>>>(arg1.data(), arg2.data(), arg3.data(), last_dim, result.data(), elem_num);
}
else
{
int factor = lens[1];
mul_add_kernel_dim4<<<block_num, block_size>>>(arg1.data(), arg2.data(), arg3.data(), factor, last_dim, result.data(), elem_num);
}
}
else
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment