Unverified Commit c35920e2 authored by zhangyue's avatar zhangyue Committed by GitHub
Browse files

Merge pull request #400 from InfiniTensor/issue/388-rmsnorm3d-p800

issue390: p800 rmsnorm 3d
parents c9f5f2ec e6dd5604
...@@ -36,13 +36,6 @@ __global__ void rmsnormKernel( ...@@ -36,13 +36,6 @@ __global__ void rmsnormKernel(
sync_cluster(); sync_cluster();
} }
// Instantiate the kernel for different data types and block sizes
INSTANTIATE_RMSNORM_KERNEL(64, float, float, float);
INSTANTIATE_RMSNORM_KERNEL(64, float, bfloat16_t, float);
INSTANTIATE_RMSNORM_KERNEL(64, float, bfloat16_t, bfloat16_t);
INSTANTIATE_RMSNORM_KERNEL(64, float, half, float);
INSTANTIATE_RMSNORM_KERNEL(64, float, half, half);
namespace op::rms_norm::kunlun { namespace op::rms_norm::kunlun {
struct Descriptor::Opaque { struct Descriptor::Opaque {
...@@ -64,13 +57,6 @@ infiniStatus_t Descriptor::create( ...@@ -64,13 +57,6 @@ infiniStatus_t Descriptor::create(
CHECK_RESULT(result); CHECK_RESULT(result);
auto info = result.take(); auto info = result.take();
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
if (info.ndim() != 2) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
new Descriptor::Opaque{static_cast<device::kunlun::Handle *>(handle)->internal()}, new Descriptor::Opaque{static_cast<device::kunlun::Handle *>(handle)->internal()},
info, info,
...@@ -82,18 +68,29 @@ infiniStatus_t Descriptor::create( ...@@ -82,18 +68,29 @@ infiniStatus_t Descriptor::create(
template <unsigned int BLOCK_SIZE> template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel( infiniStatus_t launchKernel(
uint32_t batch_size, uint32_t dim, size_t batch_size, size_t nhead, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y, void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
const void *x, ptrdiff_t stride_x, const void *x, ptrdiff_t stride_x_batch, ptrdiff_t stride_x_nhead,
const void *w, infiniDtype_t wtype, const void *w, infiniDtype_t wtype,
float epsilon, float epsilon,
kunlunStream_t stream) { kunlunStream_t stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \ uint32_t batch_size_ = static_cast<uint32_t>(batch_size);
rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, stream>>>( \ uint32_t dim_ = static_cast<uint32_t>(dim);
static_cast<Tdata *>(y), stride_y, \ uint32_t nhead_ = static_cast<uint32_t>(nhead);
static_cast<const Tdata *>(x), stride_x, \ auto stride_x_batch_ = static_cast<int32_t>(stride_x_batch);
static_cast<const Tweight *>(w), dim, epsilon); auto stride_x_nhead_ = static_cast<int32_t>(stride_x_nhead);
auto stride_y_batch_ = static_cast<int32_t>(stride_y_batch);
auto stride_y_nhead_ = static_cast<int32_t>(stride_y_nhead);
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
for (uint32_t i = 0; i < batch_size_; ++i) { \
rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight> \
<<<nhead_, BLOCK_SIZE, stream>>>( \
(Tdata *)y + i * stride_y_batch_, stride_y_nhead_, \
(const Tdata *)x + i * stride_x_batch_, stride_x_nhead_, \
(const Tweight *)w, dim, epsilon); \
}
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float); LAUNCH_KERNEL(half, half, float);
...@@ -123,13 +120,21 @@ infiniStatus_t Descriptor::calculate( ...@@ -123,13 +120,21 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_INSUFFICIENT_WORKSPACE; return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
} }
auto stride_x = static_cast<int32_t>(_info.x_strides[0]); auto stride_x_batch = _info.x_strides[0];
auto stride_y = static_cast<int32_t>(_info.y_strides[0]); auto stride_x_nhead = _info.x_strides[1];
auto dim = static_cast<uint32_t>(_info.dim()); auto stride_y_batch = _info.y_strides[0];
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]); auto stride_y_nhead = _info.y_strides[1];
auto dim = _info.dim();
auto batch_size = _info.shape[0];
auto nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
kunlunStream_t stream_ = static_cast<kunlunStream_t>(stream);
// launch kernel with different block sizes // launch kernel with different block sizes
CHECK_STATUS(launchKernel<64>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, stream)); CHECK_STATUS(launchKernel<64>(batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
x, stride_x_batch, stride_x_nhead,
w, _info.wtype, _info.epsilon, stream_));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment