Unverified commit dc8ddd58, authored by Ziminli, committed by GitHub
Browse files

issue/388: Support 3D Cases for RMS Norm

parent 47895fae
...@@ -21,19 +21,27 @@ infiniStatus_t Descriptor::create( ...@@ -21,19 +21,27 @@ infiniStatus_t Descriptor::create(
template <typename T> template <typename T>
infiniStatus_t rmsnorm(const RMSNormInfo *info, T *y, const T *x, const T *w) { infiniStatus_t rmsnorm(const RMSNormInfo *info, T *y, const T *x, const T *w) {
const size_t batch_size = info->shape[0];
const size_t nhead = info->shape.size() > 2 ? info->shape[1] : 1;
const size_t dim = info->shape.back();
const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);
#pragma omp parallel for #pragma omp parallel for
for (ptrdiff_t i = 0; i < ptrdiff_t(info->shape[0]); i++) { for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) {
T *x_ = (T *)(x + i * info->x_strides[0]); const size_t i = block_idx / nhead; // batch index
T *y_ = (T *)(y + i * info->y_strides[0]); const size_t j = block_idx % nhead; // head index
const T *x_ptr = x + i * info->x_strides[0] + j * info->x_strides[1];
T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
// [Reduce] sum of x^2 on last dimension // [Reduce] sum of x^2 on last dimension
T ss = op::common_cpu::reduce_op::sumSquared(x_, info->shape[1], info->x_strides[1]); T ss = op::common_cpu::reduce_op::sumSquared(x_ptr, dim, info->x_strides.back());
// 1 / (sqrt(sum/dim + eps)) // 1 / (sqrt(sum/dim + eps))
T rms = (T)1 / std::sqrt(ss / (T)(info->shape[1]) + (T)(info->epsilon)); T rms = (T)1 / std::sqrt(ss / (T)(dim) + (T)(info->epsilon));
for (size_t j = 0; j < info->shape[1]; j++) { for (size_t k = 0; k < dim; k++) {
y_[j * info->y_strides[1]] = x_[j * info->x_strides[1]] * w[j] * rms; y_ptr[k] = x_ptr[k] * w[k] * rms;
} }
} }
...@@ -45,24 +53,32 @@ infiniStatus_t rmsnormHalfPrecision(const RMSNormInfo *info, T *y, const T *x, c ...@@ -45,24 +53,32 @@ infiniStatus_t rmsnormHalfPrecision(const RMSNormInfo *info, T *y, const T *x, c
static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value, static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value,
"T must be fp16_t or bf16_t"); "T must be fp16_t or bf16_t");
const size_t batch_size = info->shape[0];
const size_t nhead = info->shape.size() > 2 ? info->shape[1] : 1;
const size_t dim = info->shape.back();
const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);
#pragma omp parallel for #pragma omp parallel for
for (ptrdiff_t i = 0; i < ptrdiff_t(info->shape[0]); i++) { for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) {
T *x_ = (T *)(x + i * info->x_strides[0]); const size_t i = block_idx / nhead; // batch index
T *y_ = (T *)(y + i * info->y_strides[0]); const size_t j = block_idx % nhead; // head index
const T *x_ptr = x + i * info->x_strides[0] + j * info->x_strides[1];
T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
// [Reduce] sum of x^2 on last dimension // [Reduce] sum of x^2 on last dimension
float ss = op::common_cpu::reduce_op::sumSquared(x_, info->shape[1], info->x_strides[1]); float ss = op::common_cpu::reduce_op::sumSquared(x_ptr, dim, info->x_strides.back());
// 1 / (sqrt(sum/dim + eps)) // 1 / (sqrt(sum/dim + eps))
float rms = 1.f / std::sqrt(ss / (float)(info->shape[1]) + info->epsilon); float rms = 1.f / std::sqrt(ss / (float)(dim) + info->epsilon);
for (size_t j = 0; j < info->shape[1]; j++) { for (size_t k = 0; k < dim; k++) {
if constexpr (std::is_same<Tw, float>::value) { if constexpr (std::is_same<Tw, float>::value) {
float val = utils::cast<float>(x_[j * info->x_strides[1]]) * w[j] * rms; float val = utils::cast<float>(x_ptr[k]) * w[k] * rms;
y_[j * info->y_strides[1]] = utils::cast<T>(val); y_ptr[k] = utils::cast<T>(val);
} else if constexpr (std::is_same<Tw, T>::value) { } else if constexpr (std::is_same<Tw, T>::value) {
float val = utils::cast<float>(x_[j * info->x_strides[1]]) * utils::cast<float>(w[j]) * rms; float val = utils::cast<float>(x_ptr[k]) * utils::cast<float>(w[k]) * rms;
y_[j * info->y_strides[1]] = utils::cast<T>(val); y_ptr[k] = utils::cast<T>(val);
} else { } else {
std::abort(); std::abort();
} }
...@@ -93,9 +109,9 @@ infiniStatus_t Descriptor::calculate( ...@@ -93,9 +109,9 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
} else if (_info.atype == INFINI_DTYPE_F32) { } else if (_info.atype == INFINI_DTYPE_F32) {
CHECK_STATUS(rmsnorm(&_info, (float *)y, (float *)x, (float *)w)); CHECK_STATUS(rmsnorm(&_info, (float *)y, (const float *)x, (const float *)w));
} else if (_info.atype == INFINI_DTYPE_F64) { } else if (_info.atype == INFINI_DTYPE_F64) {
CHECK_STATUS(rmsnorm(&_info, (double *)y, (double *)x, (double *)w)); CHECK_STATUS(rmsnorm(&_info, (double *)y, (const double *)x, (const double *)w));
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
......
...@@ -4,16 +4,22 @@ ...@@ -4,16 +4,22 @@
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight> template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
__device__ void rmsnormBlock( __device__ void rmsnormBlock(
Tdata *__restrict__ y, Tdata *__restrict__ y,
ptrdiff_t stride_y, ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
const Tdata *__restrict__ x, const Tdata *__restrict__ x,
ptrdiff_t stride_x, ptrdiff_t stride_x_batch,
ptrdiff_t stride_x_nhead,
const Tweight *__restrict__ w, const Tweight *__restrict__ w,
size_t nhead,
size_t dim, size_t dim,
float epsilon) { float epsilon) {
// Each block takes care of a row of continuous data of length dim // Each block takes care of one head in one batch
// Each thread deals with every block_size element in the row // Each thread deals with every block_size element in the row
auto y_ptr = y + blockIdx.x * stride_y; size_t batch_idx = blockIdx.x / nhead;
auto x_ptr = x + blockIdx.x * stride_x; size_t head_idx = blockIdx.x % nhead;
auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
auto x_ptr = x + batch_idx * stride_x_batch + head_idx * stride_x_nhead;
auto w_ptr = w; auto w_ptr = w;
// Block-reduce sum of x^2 // Block-reduce sum of x^2
......
...@@ -46,21 +46,39 @@ public: ...@@ -46,21 +46,39 @@ public:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
if (y_desc->ndim() != 2 || x_desc->ndim() != 2 || w_desc->ndim() != 1) { const size_t y_ndim = y_desc->ndim();
return INFINI_STATUS_BAD_TENSOR_SHAPE; const size_t x_ndim = x_desc->ndim();
} const size_t w_ndim = w_desc->ndim();
size_t batch = y_desc->shape()[0]; if (y_ndim != x_ndim || w_ndim != 1) {
size_t dim = y_desc->shape()[1];
if (x_desc->shape()[0] != batch || x_desc->shape()[1] != dim || w_desc->shape()[0] != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE; return INFINI_STATUS_BAD_TENSOR_SHAPE;
} }
if (w_desc->stride(0) != 1) { size_t batch = 1;
return INFINI_STATUS_BAD_TENSOR_STRIDES; size_t nhead = 1;
size_t dim = 0;
if (y_ndim == 2) {
batch = y_desc->dim(0);
dim = y_desc->dim(1);
if (x_desc->dim(0) != batch || x_desc->dim(1) != dim || w_desc->dim(0) != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
} else if (y_ndim == 3) {
batch = y_desc->dim(0);
nhead = y_desc->dim(1);
dim = y_desc->dim(2);
if (x_desc->dim(0) != batch || x_desc->dim(1) != nhead || x_desc->dim(2) != dim || w_desc->dim(0) != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
} }
if (x_desc->stride(1) != 1 || y_desc->stride(1) != 1) { // Check contiguity of the last dimension
if (y_desc->stride(y_ndim - 1) != 1 || x_desc->stride(x_ndim - 1) != 1 || w_desc->stride(w_ndim - 1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES; return INFINI_STATUS_BAD_TENSOR_STRIDES;
} }
......
...@@ -11,13 +11,16 @@ ...@@ -11,13 +11,16 @@
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight> template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_METAX_KERNEL rmsnormKernel( INFINIOP_METAX_KERNEL rmsnormKernel(
Tdata *__restrict__ y, Tdata *__restrict__ y,
ptrdiff_t stride_y, ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
const Tdata *__restrict__ x, const Tdata *__restrict__ x,
ptrdiff_t stride_x, ptrdiff_t stride_x_batch,
ptrdiff_t stride_x_nhead,
const Tweight *__restrict__ w, const Tweight *__restrict__ w,
size_t nhead,
size_t dim, size_t dim,
float epsilon) { float epsilon) {
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon); rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, nhead, dim, epsilon);
} }
namespace op::rms_norm::metax { namespace op::rms_norm::metax {
...@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create( ...@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create(
CHECK_RESULT(result); CHECK_RESULT(result);
auto info = result.take(); auto info = result.take();
// only support contiguous last dimension
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()}, new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
std::move(info), std::move(info),
...@@ -57,21 +55,24 @@ infiniStatus_t Descriptor::create( ...@@ -57,21 +55,24 @@ infiniStatus_t Descriptor::create(
// launch kernel with different data types // launch kernel with different data types
template <unsigned int BLOCK_SIZE> template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel( infiniStatus_t launchKernel(
uint32_t batch_size, size_t dim, uint32_t batch_size, size_t nhead, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y, void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
const void *x, ptrdiff_t stride_x, const void *x, ptrdiff_t stride_x_batch, ptrdiff_t stride_x_nhead,
const void *w, infiniDtype_t wtype, const void *w, infiniDtype_t wtype,
float epsilon, float epsilon,
hcStream_t stream) { hcStream_t stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \ #define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, stream>>>( \ rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, stream>>>( \
reinterpret_cast<Tdata *>(y), \ reinterpret_cast<Tdata *>(y), \
stride_y, \ stride_y_batch, \
reinterpret_cast<const Tdata *>(x), \ stride_y_nhead, \
stride_x, \ reinterpret_cast<const Tdata *>(x), \
reinterpret_cast<const Tweight *>(w), \ stride_x_batch, \
dim, \ stride_x_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
dim, \
epsilon) epsilon)
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
...@@ -102,15 +103,18 @@ infiniStatus_t Descriptor::calculate( ...@@ -102,15 +103,18 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_INSUFFICIENT_WORKSPACE; return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
} }
auto stride_x = _info.x_strides[0]; auto stride_x_batch = _info.x_strides[0];
auto stride_y = _info.y_strides[0]; auto stride_x_nhead = _info.x_strides[1];
auto stride_y_batch = _info.y_strides[0];
auto stride_y_nhead = _info.y_strides[1];
auto dim = _info.dim(); auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]); uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto stream = reinterpret_cast<hcStream_t>(stream_); auto stream = reinterpret_cast<hcStream_t>(stream_);
// launch kernel with different block sizes // launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) { if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, stream)); CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, stream));
} else { } else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
} }
......
...@@ -11,13 +11,16 @@ ...@@ -11,13 +11,16 @@
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight> template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MOORE_KERNEL rmsnormKernel( INFINIOP_MOORE_KERNEL rmsnormKernel(
Tdata *__restrict__ y, Tdata *__restrict__ y,
ptrdiff_t stride_y, ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
const Tdata *__restrict__ x, const Tdata *__restrict__ x,
ptrdiff_t stride_x, ptrdiff_t stride_x_batch,
ptrdiff_t stride_x_nhead,
const Tweight *__restrict__ w, const Tweight *__restrict__ w,
size_t nhead,
size_t dim, size_t dim,
float epsilon) { float epsilon) {
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon); rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, nhead, dim, epsilon);
} }
namespace op::rms_norm::moore { namespace op::rms_norm::moore {
...@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create( ...@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create(
CHECK_RESULT(result); CHECK_RESULT(result);
auto info = result.take(); auto info = result.take();
// only support contiguous last dimension
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()}, new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
std::move(info), std::move(info),
...@@ -57,20 +55,23 @@ infiniStatus_t Descriptor::create( ...@@ -57,20 +55,23 @@ infiniStatus_t Descriptor::create(
// launch kernel with different data types // launch kernel with different data types
template <unsigned int BLOCK_SIZE> template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel( infiniStatus_t launchKernel(
uint32_t batch_size, size_t dim, uint32_t batch_size, size_t nhead, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y, void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
const void *x, ptrdiff_t stride_x, const void *x, ptrdiff_t stride_x_batch, ptrdiff_t stride_x_nhead,
const void *w, infiniDtype_t wtype, const void *w, infiniDtype_t wtype,
float epsilon, float epsilon,
musaStream_t musa_stream) { musaStream_t musa_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \ #define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>( \ rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, musa_stream>>>( \
reinterpret_cast<Tdata *>(y), \ reinterpret_cast<Tdata *>(y), \
stride_y, \ stride_y_batch, \
stride_y_nhead, \
reinterpret_cast<const Tdata *>(x), \ reinterpret_cast<const Tdata *>(x), \
stride_x, \ stride_x_batch, \
stride_x_nhead, \
reinterpret_cast<const Tweight *>(w), \ reinterpret_cast<const Tweight *>(w), \
nhead, \
dim, \ dim, \
epsilon) epsilon)
...@@ -102,19 +103,22 @@ infiniStatus_t Descriptor::calculate( ...@@ -102,19 +103,22 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_INSUFFICIENT_WORKSPACE; return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
} }
auto stride_x = _info.x_strides[0]; auto stride_x_batch = _info.x_strides[0];
auto stride_y = _info.y_strides[0]; auto stride_x_nhead = _info.x_strides[1];
auto stride_y_batch = _info.y_strides[0];
auto stride_y_nhead = _info.y_strides[1];
auto dim = _info.dim(); auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]); uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto musa_stream = reinterpret_cast<musaStream_t>(stream); auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// launch kernel with different block sizes // launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) { if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream)); CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) { } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream)); CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) { } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream)); CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, musa_stream));
} else { } else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
} }
......
...@@ -11,13 +11,16 @@ ...@@ -11,13 +11,16 @@
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight> template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL rmsnormKernel( INFINIOP_CUDA_KERNEL rmsnormKernel(
Tdata *__restrict__ y, Tdata *__restrict__ y,
ptrdiff_t stride_y, ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
const Tdata *__restrict__ x, const Tdata *__restrict__ x,
ptrdiff_t stride_x, ptrdiff_t stride_x_batch,
ptrdiff_t stride_x_nhead,
const Tweight *__restrict__ w, const Tweight *__restrict__ w,
size_t nhead,
size_t dim, size_t dim,
float epsilon) { float epsilon) {
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon); rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, nhead, dim, epsilon);
} }
namespace op::rms_norm::nvidia { namespace op::rms_norm::nvidia {
...@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create( ...@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create(
CHECK_RESULT(result); CHECK_RESULT(result);
auto info = result.take(); auto info = result.take();
// only support contiguous last dimension
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()}, new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
std::move(info), std::move(info),
...@@ -57,21 +55,24 @@ infiniStatus_t Descriptor::create( ...@@ -57,21 +55,24 @@ infiniStatus_t Descriptor::create(
// launch kernel with different data types // launch kernel with different data types
template <unsigned int BLOCK_SIZE> template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel( infiniStatus_t launchKernel(
uint32_t batch_size, size_t dim, uint32_t batch_size, size_t nhead, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y, void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
const void *x, ptrdiff_t stride_x, const void *x, ptrdiff_t stride_x_batch, ptrdiff_t stride_x_nhead,
const void *w, infiniDtype_t wtype, const void *w, infiniDtype_t wtype,
float epsilon, float epsilon,
cudaStream_t cuda_stream) { cudaStream_t cuda_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \ #define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, cuda_stream>>>( \ rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, cuda_stream>>>( \
reinterpret_cast<Tdata *>(y), \ reinterpret_cast<Tdata *>(y), \
stride_y, \ stride_y_batch, \
reinterpret_cast<const Tdata *>(x), \ stride_y_nhead, \
stride_x, \ reinterpret_cast<const Tdata *>(x), \
reinterpret_cast<const Tweight *>(w), \ stride_x_batch, \
dim, \ stride_x_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
dim, \
epsilon) epsilon)
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
...@@ -102,19 +103,22 @@ infiniStatus_t Descriptor::calculate( ...@@ -102,19 +103,22 @@ infiniStatus_t Descriptor::calculate(
return INFINI_STATUS_INSUFFICIENT_WORKSPACE; return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
} }
auto stride_x = _info.x_strides[0]; auto stride_x_batch = _info.x_strides[0];
auto stride_y = _info.y_strides[0]; auto stride_x_nhead = _info.x_strides[1];
auto stride_y_batch = _info.y_strides[0];
auto stride_y_nhead = _info.y_strides[1];
auto dim = _info.dim(); auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]); uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream); auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
// launch kernel with different block sizes // launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, cuda_stream)); CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, cuda_stream)); CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_4096>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, cuda_stream)); CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_4096>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream));
} else { } else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
} }
......
...@@ -97,6 +97,10 @@ if __name__ == "__main__": ...@@ -97,6 +97,10 @@ if __name__ == "__main__":
((500, 4096), None, (8192, 1)), ((500, 4096), None, (8192, 1)),
((4, 512), (1024, 1), (512, 1)), ((4, 512), (1024, 1), (512, 1)),
((4, 512), None, (2048, 1)), ((4, 512), None, (2048, 1)),
((3, 4, 512), None, None),
((3, 4, 512), None, (4096, 1024, 1)),
((3, 4, 512), (4096, 1024, 1), None),
((3, 4, 512), (4096, 1024, 1), (4096, 1024, 1)),
] ]
_TENSOR_DTYPES_ = [np.float32, np.float16] _TENSOR_DTYPES_ = [np.float32, np.float16]
for dtype in _TENSOR_DTYPES_: for dtype in _TENSOR_DTYPES_:
......
...@@ -25,11 +25,14 @@ from libinfiniop import ( ...@@ -25,11 +25,14 @@ from libinfiniop import (
_TEST_CASES_ = [ _TEST_CASES_ = [
# y_shape, x_shape, w_shape, y_stride, x_stride # y_shape, x_shape, w_shape, y_stride, x_stride
((1, 4), (1, 4), (4,), None, None), ((1, 4), (1, 4), (4,), None, None),
((1, 4), (1, 4), (4,), None, None), ((2, 4), (2, 4), (4,), None, None),
((16, 2048), (16, 2048), (2048,), None, None), ((2, 2, 4), (2, 2, 4), (4,), None, None),
((2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1)),
((16, 2048), (16, 2048), (2048,), None, None), ((16, 2048), (16, 2048), (2048,), None, None),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)), ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)), ((4, 4, 2048), (4, 4, 2048), (2048,), None, None),
((4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1)),
((4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1)),
] ]
# w (weight) types # w (weight) types
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment