Commit 1ce84cf5 authored by Shucai Xiao
Browse files

fix kernels related to add, mul, and mul_add

parent 87cd03e0
......@@ -45,7 +45,7 @@ void add(hipStream_t stream, const argument& result, const argument& arg1, const
auto last_dim = sr.lens().back() / 2;
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
add_kernel<<<block_num, block_size>>>(
add_kernel<<<block_num, block_size, 0, stream>>>(
arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
}
else
......
......@@ -45,7 +45,7 @@ void mul(hipStream_t stream, const argument& result, const argument& arg1, const
auto last_dim = sr.lens().back() / 2;
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
mul_kernel<<<block_num, block_size>>>(
mul_kernel<<<block_num, block_size, 0, stream>>>(
arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
}
else
......
......@@ -34,7 +34,7 @@ __global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int d
__half2* hr = reinterpret_cast<__half2*>(r);
if(id < n)
{
int idb = id / factor + id % dim4;
int idb = id / (factor * dim4) * dim4 + id % dim4;
hr[id] = __hadd2(__hmul2(ha[id], hx[id]), hb[idb]);
}
}
......@@ -70,23 +70,23 @@ void mul_add(hipStream_t stream,
ss.push_back(arg1.get_shape());
ss.push_back(arg2.get_shape());
ss.push_back(arg3.get_shape());
auto lens = sr.lens();
int last_dim = lens.back() / 2;
auto n_dim = lens.size();
if(type == shape::half_type and is_bert(ss))
{
auto elem_num = sr.elements() / 2;
auto lens = sr.lens();
int last_dim = lens.back() / 2;
auto n_dim = lens.size();
int block_size = 1024;
int block_num = (elem_num + block_size - 1) / block_size;
if(n_dim == 2)
{
mul_add_kernel_dim3<<<block_num, block_size>>>(
mul_add_kernel_dim3<<<block_num, block_size, 0, stream>>>(
arg1.data(), arg2.data(), arg3.data(), last_dim, result.data(), elem_num);
}
else
{
int factor = lens[1];
mul_add_kernel_dim4<<<block_num, block_size>>>(
mul_add_kernel_dim4<<<block_num, block_size, 0, stream>>>(
arg1.data(), arg2.data(), arg3.data(), factor, last_dim, result.data(), elem_num);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment