Commit d854dbee authored by zhangyue

issue/111: Merge branch 'main' of github.com:PanZezhong1725/InfiniCore into issue/111-rmsnorm-kunlun
parents 23ddc20b a474a6f5
......@@ -18,14 +18,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
dtype, info, 0,
dtype, result.take(), 0,
nullptr,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
......
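The hunks above and below all follow the same migration: the old out-parameter constructor (`MatmulInfo(..., &status, ...)`) is replaced by a fallible factory returning `utils::Result`. A minimal sketch of what such a Result type and the `CHECK_RESULT` macro plausibly look like, assuming a design along these lines (the real definitions live in `utils.h` and may differ):

```cpp
#include <utility>

// Sketch only: a status-or-value wrapper matching the call sites in this
// commit (CHECK_RESULT, take(), operator->). This simplified version requires
// T to be default-constructible and movable; the real type likely stores the
// value in a union or variant instead.
template <typename T>
class Result {
    infiniStatus_t _status; // infiniStatus_t comes from the infinicore headers
    T _value;               // meaningful only on INFINI_STATUS_SUCCESS
public:
    Result(infiniStatus_t status) : _status(status), _value{} {}
    Result(T value) : _status(INFINI_STATUS_SUCCESS), _value(std::move(value)) {}
    infiniStatus_t status() const { return _status; }
    explicit operator bool() const { return _status == INFINI_STATUS_SUCCESS; }
    T take() { return std::move(_value); }
    T *operator->() { return &_value; }
};

// Early-return on failure, as used at every create() call site in this commit.
#define CHECK_RESULT(RESULT)          \
    do {                              \
        if (!(RESULT)) {              \
            return (RESULT).status(); \
        }                             \
    } while (0)
```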
......@@ -24,14 +24,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
dtype, info, 0,
dtype, result.take(), 0,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
......
......@@ -2,7 +2,7 @@
#define __GEMM_H__
#include "../../operator.h"
#include "blas.h"
#include "info.h"
/**
 * # Notes on `DESCRIPTOR(NAMESPACE)` and `struct Opaque;`
......@@ -44,50 +44,50 @@
 * This macro applies only to matrix multiplication, but the pattern is easy to replicate for other operators to simplify and standardize their declarations.
*/
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::gemm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
infiniDtype_t _dtype; \
MatmulInfo _info; \
\
Descriptor( \
infiniDtype_t dtype, \
MatmulInfo info, \
size_t workspace_size_, \
Opaque *opaque, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_dtype(dtype), \
_info(info), \
workspace_size(workspace_size_) {} \
\
public: \
size_t workspace_size; \
\
~Descriptor(); \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t c_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc); \
\
infiniStatus_t calculate( \
void *workspace, \
size_t workspace_size, \
void *c, \
float beta, \
const void *a, \
const void *b, \
float alpha, \
void *stream) const; \
}; \
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::gemm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
infiniDtype_t _dtype; \
MatmulInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
infiniDtype_t dtype, \
MatmulInfo info, \
size_t workspace_size_, \
Opaque *opaque, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_dtype(dtype), \
_info(info), \
_workspace_size(workspace_size_) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t c_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *c, \
float beta, \
const void *a, \
const void *b, \
float alpha, \
void *stream) const; \
}; \
}
#endif // __GEMM_H__
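With the macro change, the workspace size moves behind a `workspaceSize()` accessor instead of a public field. A hedged usage sketch of the generated per-backend class (direct use of the namespace is hypothetical for illustration; callers normally go through the `infiniop*` C API shown further down):

```cpp
// Sketch: lifecycle of a generated gemm descriptor (backend name `cpu` and
// the surrounding handle/tensor descriptors are assumptions).
op::gemm::cpu::Descriptor *desc = nullptr;
CHECK_STATUS(op::gemm::cpu::Descriptor::create(handle, &desc, c_desc, a_desc, b_desc));

size_t ws_size = desc->workspaceSize(); // replaces the old public workspace_size field
void *workspace = nullptr;              // allocate ws_size bytes with the runtime
CHECK_STATUS(desc->calculate(workspace, ws_size,
                             c, /*beta=*/0.0f, a, b, /*alpha=*/1.0f, stream));
delete desc;
```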
#ifndef __GEMM_INFO_H__
#define __GEMM_INFO_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include <algorithm>
namespace op::gemm {
class BlasMatrix {
BlasMatrix() = default;
public:
size_t ndim;
size_t batch;
ptrdiff_t stride;
size_t rows;
size_t cols;
ptrdiff_t row_stride;
ptrdiff_t col_stride;
static utils::Result<BlasMatrix> create(infiniopTensorDescriptor_t layout) {
BlasMatrix ans;
if (layout->ndim() == 2) {
ans.ndim = 2;
ans.batch = 1;
ans.stride = 0;
ans.rows = layout->dim(0);
ans.cols = layout->dim(1);
ans.row_stride = layout->stride(0);
ans.col_stride = layout->stride(1);
} else if (layout->ndim() == 3) {
ans.ndim = 3;
ans.batch = layout->dim(0);
ans.stride = ans.batch == 1 ? 0 : layout->stride(0);
ans.rows = layout->dim(1);
ans.cols = layout->dim(2);
ans.row_stride = layout->stride(1);
ans.col_stride = layout->stride(2);
} else {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (ans.row_stride != 1 && ans.col_stride != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
return utils::Result<BlasMatrix>(ans);
}
bool match_batch(size_t _batch) const {
return batch == _batch || batch == 1;
}
void transpose() {
std::swap(rows, cols);
std::swap(row_stride, col_stride);
}
ptrdiff_t ld() const {
return row_stride == 1 ? col_stride : row_stride;
}
};
enum class MatrixLayout : char {
COL_MAJOR,
ROW_MAJOR,
};
class MatmulInfo {
MatmulInfo() = default;
public:
BlasMatrix a_matrix;
BlasMatrix b_matrix;
BlasMatrix c_matrix;
size_t m, n, k, batch;
bool is_transed;
static utils::Result<MatmulInfo> create(
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
MatrixLayout layout) {
auto a_matrix = BlasMatrix::create(a_desc);
CHECK_RESULT(a_matrix);
auto b_matrix = BlasMatrix::create(b_desc);
CHECK_RESULT(b_matrix);
auto c_matrix = BlasMatrix::create(c_desc);
CHECK_RESULT(c_matrix);
if (c_matrix->rows != a_matrix->rows || c_matrix->cols != b_matrix->cols || a_matrix->cols != b_matrix->rows) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
auto batch = c_matrix->batch;
if (!a_matrix->match_batch(batch) || !b_matrix->match_batch(batch)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
auto is_transed = false;
if ((layout == MatrixLayout::COL_MAJOR && c_matrix->col_stride == 1)
|| (layout == MatrixLayout::ROW_MAJOR && c_matrix->row_stride == 1)) {
c_matrix->transpose();
b_matrix->transpose();
a_matrix->transpose();
std::swap(a_matrix, b_matrix);
is_transed = true;
}
auto m = c_matrix->rows;
auto n = c_matrix->cols;
auto k = a_matrix->cols;
return utils::Result<MatmulInfo>(MatmulInfo{
a_matrix.take(),
b_matrix.take(),
c_matrix.take(),
m,
n,
k,
batch,
is_transed});
}
};
} // namespace op::gemm
#endif // __GEMM_INFO_H__
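A worked example of what `MatmulInfo::create` produces, sketched as comments (the concrete numbers assume contiguous row-major tensors):

```cpp
// Example: C[4, 64, 32] = A[4, 64, 128] * B[4, 128, 32], all contiguous
// row-major, requested layout COL_MAJOR.
//
// BlasMatrix::create on A yields: batch = 4, stride = 64 * 128,
// rows = 64, cols = 128, row_stride = 128, col_stride = 1.
//
// C is row-major, so c_matrix.col_stride == 1 while COL_MAJOR was requested:
// create() transposes all three matrices and swaps A and B, i.e. it computes
// C^T = B^T * A^T. The result: m = 32, n = 64, k = 128, batch = 4,
// is_transed = true, and each ld() is the stride of the non-unit dimension
// (128 for the original A, 32 for B, 32 for C).
```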
......@@ -27,14 +27,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::ROW_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::ROW_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
dtype, info, 0,
dtype, result.take(), 0,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
......
......@@ -25,14 +25,11 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
infiniStatus_t status;
auto info = MatmulInfo(c_desc, a_desc, b_desc, &status, MatrixLayout::COL_MAJOR);
if (status != INFINI_STATUS_SUCCESS) {
return status;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
dtype, info, 0,
dtype, result.take(), 0,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
......
#ifndef __GEMM_MUSA_H__
#define __GEMM_MUSA_H__
#include "../gemm.h"
DESCRIPTOR(musa)
#endif // __GEMM_MUSA_H__
#include "../../../devices/musa/common_musa.h"
#include "../../../devices/musa/musa_handle.h"
#include "gemm_musa.h"
namespace op::gemm::musa {
struct Descriptor::Opaque {
std::shared_ptr<device::musa::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<device::musa::Handle *>(handle_);
auto dtype = c_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
dtype, result.take(), 0,
new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename Tdata>
infiniStatus_t calculate(
const MatmulInfo &info,
std::shared_ptr<device::musa::Handle::Internal> &_internal,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream) {
musaDataType a_type, b_type, c_type;
mublasComputeType_t compute_type;
Tdata alpha_, beta_;
if constexpr (std::is_same<Tdata, half>::value) {
alpha_ = __float2half(alpha);
beta_ = __float2half(beta);
a_type = b_type = c_type = MUSA_R_16F;
compute_type = MUBLAS_COMPUTE_16F;
} else {
alpha_ = alpha;
beta_ = beta;
a_type = b_type = c_type = MUSA_R_32F;
compute_type = MUBLAS_COMPUTE_32F_FAST_TF32;
}
if (info.is_transed) {
std::swap(a, b);
}
auto op_a = info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
auto op_b = info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
CHECK_STATUS(_internal->useMublas(
(musaStream_t)stream,
[&](mublasHandle_t handle) {
CHECK_MUBLAS(
mublasGemmStridedBatchedEx(
handle,
op_a,
op_b,
static_cast<int>(info.m),
static_cast<int>(info.n),
static_cast<int>(info.k),
&alpha_,
a,
a_type,
static_cast<int>(info.a_matrix.ld()),
info.a_matrix.stride,
b,
b_type,
static_cast<int>(info.b_matrix.ld()),
info.b_matrix.stride,
&beta_,
c,
c_type,
static_cast<int>(info.c_matrix.ld()),
info.c_matrix.stride,
static_cast<int>(info.batch),
compute_type,
MUBLAS_GEMM_DEFAULT));
return INFINI_STATUS_SUCCESS;
}));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace,
size_t workspace_size,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
return musa::calculate<half>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
case INFINI_DTYPE_F32:
return musa::calculate<float>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
} // namespace op::gemm::musa
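Why the operand swap and `MUBLAS_OP_*` selection above work, sketched briefly (reasoning only; mublas is assumed to follow cuBLAS's column-major convention):

```cpp
// mublas, like cuBLAS, is column-major. A row-major matrix (col_stride == 1)
// is what a column-major BLAS sees as the transpose of itself, so:
//
//   row_stride == 1 -> storage already column-major -> MUBLAS_OP_N
//   otherwise       -> row-major storage            -> MUBLAS_OP_T
//
// When C itself was row-major, MatmulInfo::create already rewrote the problem
// as C^T = B^T * A^T (is_transed == true), which is why a and b are swapped
// before choosing op_a/op_b. ld() is the stride of the non-unit dimension,
// i.e. exactly the leading dimension mublasGemmStridedBatchedEx expects.
```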
......@@ -17,6 +17,9 @@
#ifdef ENABLE_METAX_API
#include "maca/gemm_maca.h"
#endif
#ifdef ENABLE_MOORE_API
#include "musa/gemm_musa.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/gemm_kunlun.h"
#endif
......@@ -54,6 +57,10 @@ __C infiniStatus_t infiniopCreateGemmDescriptor(
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, maca);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, musa);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
......@@ -70,9 +77,9 @@ infiniopGetGemmWorkspaceSize(
infiniopGemmDescriptor_t desc,
size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspace_size; \
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
......@@ -92,6 +99,9 @@ infiniopGetGemmWorkspaceSize(
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, maca);
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, musa);
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
......@@ -138,6 +148,9 @@ __C infiniStatus_t infiniopGemm(
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, maca);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, musa);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
......@@ -174,6 +187,9 @@ infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) {
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, maca);
#endif
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, musa);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
......
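Only the `GET` macro's definition appears in this diff; `CREATE`, `CALCULATE`, and `DELETE` are used but defined outside the shown hunks. A plausible reconstruction patterned after `GET` (an assumption, not the repository's exact code):

```cpp
#define CREATE(CASE, NAMESPACE)                                                 \
    case CASE:                                                                  \
        return op::gemm::NAMESPACE::Descriptor::create(                         \
            handle,                                                             \
            reinterpret_cast<op::gemm::NAMESPACE::Descriptor **>(desc_ptr),     \
            c_desc, a_desc, b_desc)

#define CALCULATE(CASE, NAMESPACE)                                              \
    case CASE:                                                                  \
        return reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc)  \
            ->calculate(workspace, workspace_size, c, beta, a, b, alpha, stream)

#define DELETE(CASE, NAMESPACE)                                                 \
    case CASE:                                                                  \
        delete reinterpret_cast<const op::gemm::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS
```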
......@@ -27,14 +27,12 @@ infiniStatus_t Descriptor::create(
auto dst_strides = y_desc->strides().data();
auto src_strides = x_desc->strides().data();
auto element_size = infiniSizeOf(dtype);
auto meta = utils::RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
if (!meta) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
auto result = utils::RearrangeMeta::create(shape, dst_strides, src_strides, ndim, element_size);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(
std::move(*meta),
result.take(),
nullptr,
handle->device,
handle->device_id);
......
#include "rms_norm_aclnn.h"
#include "../../../devices/ascend/common_ascend.h"
#include <aclnnop/aclnn_rms_norm.h>
namespace op::rms_norm::ascend {
struct Descriptor::Opaque {
mutable aclOpExecutor *executor;
aclnnTensorDescriptor_t y;
aclnnTensorDescriptor_t x;
aclnnTensorDescriptor_t w;
aclnnTensorDescriptor_t rstd;
size_t workspaceSize;
~Opaque() {
delete y;
delete x;
delete w;
delete rstd;
aclDestroyAclOpExecutor(executor);
}
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
size_t workspace_size = 0;
aclOpExecutor *executor = nullptr;
aclnnTensorDescriptor_t y = nullptr;
aclnnTensorDescriptor_t x = nullptr;
aclnnTensorDescriptor_t w = nullptr;
aclnnTensorDescriptor_t rstd = nullptr;
std::vector<int64_t> slice_shape = {static_cast<int64_t>((info.shape)[1])};
auto slice_stride = std::vector<int64_t>(1, 1);
y = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride);
x = new aclnnTensorDescriptor(toAclDataType(info.atype), slice_shape, slice_stride);
w = new aclnnTensorDescriptor(w_desc);
// Get AclTensor
aclTensor *ty = y->tensor;
aclTensor *tx = x->tensor;
aclTensor *tw = w->tensor;
// Set rstdDesc
// See: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha002/apiref/appdevgapi/context/aclnnRmsNorm.md
// rstdTensor cannot set nullptr in aclnn
auto rstd_shape = std::vector<int64_t>(1, 1);
auto rstd_strides = std::vector<int64_t>(1, 1);
rstd = new aclnnTensorDescriptor(toAclDataType(INFINI_DTYPE_F32), rstd_shape, rstd_strides);
aclTensor *trstd = rstd->tensor;
// Get WorkspaceSize and set executor
CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, tw, static_cast<double>(epsilon), ty, trstd, &workspace_size, &executor));
aclSetAclOpExecutorRepeatable(executor);
auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
size_t all_workspace_size = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType);
*desc_ptr = new Descriptor(
new Opaque{executor, y, x, w, rstd, workspace_size},
std::move(info),
all_workspace_size,
handle_ascend->device, handle_ascend->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *x, const void *w,
void *stream) const {
if (workspace_size < workspaceSize()) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
auto tw = _opaque->w->tensor;
auto tx = _opaque->x->tensor;
auto ty = _opaque->y->tensor;
auto trstd = _opaque->rstd->tensor;
void *rstdPtr = (void *)((uint8_t *)workspace + _opaque->workspaceSize);
auto unit = infiniSizeOf(_info.atype);
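// The y/x aclnn descriptors were built in create() for a single row
// (slice_shape = {dim}), so the kernel is replayed once per batch row below,
// rebasing the x/y addresses on each iteration; w and rstd (indices 1 and 3)
// do not vary per row and are set once. Reusing the executor this way relies
// on aclSetAclOpExecutorRepeatable having been called in create().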
AclSetTensorAddr(_opaque->executor, 1, tw, (void *)w);
AclSetTensorAddr(_opaque->executor, 3, trstd, rstdPtr);
for (size_t i = 0; i < (_info.shape)[0]; ++i) {
AclSetTensorAddr(_opaque->executor, 0, tx, ((char *)x) + i * (_info.x_strides)[0] * unit);
AclSetTensorAddr(_opaque->executor, 2, ty, ((char *)y) + i * (_info.y_strides)[0] * unit);
CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, _opaque->executor, stream));
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::rms_norm::ascend
#ifndef __ACLNN_RMS_NORM_H__
#define __ACLNN_RMS_NORM_H__
#include "../rms_norm.h"
DESCRIPTOR(ascend)
#endif
......@@ -3,6 +3,7 @@
#include "../../../reduce/cpu/reduce.h"
namespace op::rms_norm::cpu {
Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create(
......@@ -12,9 +13,9 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
RMSNormInfo info;
CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon));
*desc_ptr = new Descriptor(nullptr, info, 0, handle->device, handle->device_id);
auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
......@@ -68,9 +69,10 @@ infiniStatus_t rmsnormF16(const RMSNormInfo *info, fp16_t *y, const fp16_t *x, c
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *y, const void *x, const void *w,
void *stream) {
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *x, const void *w,
void *stream) const {
if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(rmsnormF16(&_info, (fp16_t *)y, (const fp16_t *)x, (const fp16_t *)w));
......
#include "../../../devices/cuda/cuda_common.cuh"
#include "rms_norm_cuda.cuh"
#include "rms_norm_kernel.cuh"
#include <memory>
#include <stdint.h>
namespace op::rms_norm::cuda {
......@@ -21,8 +19,9 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
RMSNormInfo info;
CHECK_STATUS(createRMSNormInfo(&info, y_desc, x_desc, w_desc, epsilon));
auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
// only support contiguous last dimension
if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
......@@ -31,7 +30,9 @@ infiniStatus_t Descriptor::create(
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::cuda::Handle *>(handle)->internal()},
info, 0, handle->device, handle->device_id);
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
......@@ -70,8 +71,11 @@ infiniStatus_t launchKernel(
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *y, const void *x, const void *w, void *stream) {
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *x, const void *w,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
......
#ifndef __RMS_NORM_INFO_H__
#define __RMS_NORM_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::rms_norm {
class RMSNormInfo {
RMSNormInfo() = default;
public:
infiniDtype_t wtype;
infiniDtype_t atype;
float epsilon;
std::vector<size_t> shape;
std::vector<ptrdiff_t> y_strides;
std::vector<ptrdiff_t> x_strides;
size_t ndim() const { return shape.size(); }
size_t dim() const { return shape[ndim() - 1]; }
static utils::Result<RMSNormInfo> create(
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto atype = y_desc->dtype();
auto wtype = w_desc->dtype();
if (x_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (atype == INFINI_DTYPE_F16) {
if (wtype != INFINI_DTYPE_F16 && wtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
if (atype != wtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (y_desc->ndim() != 2 || x_desc->ndim() != 2 || w_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch = y_desc->shape()[0];
size_t dim = y_desc->shape()[1];
if (x_desc->shape()[0] != batch || x_desc->shape()[1] != dim || w_desc->shape()[0] != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (w_desc->stride(0) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
if (x_desc->stride(1) != 1 || y_desc->stride(1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
return utils::Result<RMSNormInfo>(RMSNormInfo{
wtype,
atype,
epsilon,
y_desc->shape(),
y_desc->strides(),
x_desc->strides(),
});
}
};
} // namespace op::rms_norm
#endif // __RMS_NORM_INFO_H__
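A hedged usage sketch of the new factory (the surrounding tensor descriptors are assumptions; backends obtain them through the C API):

```cpp
// y[batch, dim] = rms_norm(x[batch, dim]) * w[dim]
auto result = op::rms_norm::RMSNormInfo::create(y_desc, x_desc, w_desc, 1e-5f);
CHECK_RESULT(result); // propagates BAD_TENSOR_DTYPE/SHAPE/STRIDES on mismatch
auto info = result.take();
// info.dim() is the normalized (last) axis length; backends iterate
// info.shape[0] rows using info.x_strides[0] / info.y_strides[0].
```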
......@@ -8,6 +8,9 @@
#ifdef ENABLE_CUDA_API
#include "cuda/rms_norm_cuda.cuh"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/rms_norm_aclnn.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/rms_norm_kunlun.h"
#endif
......@@ -45,15 +48,8 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
return bangCreateRMSNormDescriptor((BangHandle_t)handle, (RMSNormBangDescriptor_t *)desc_ptr, y_desc, x_desc, w_desc, epsilon);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return aclnnCreateRMSNormDescriptor((AscendHandle_t)handle,
(RMSNormAclnnDescriptor_t *)desc_ptr,
y_desc,
x_desc,
w_desc,
epsilon);
}
#ifdef ENABLE_ASCEND_API
CREATE(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
......@@ -94,11 +90,8 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
return bangGetRMSNormWorkspaceSize((RMSNormBangDescriptor_t)desc, size);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return aclnnGetRMSNormWorkspaceSize((RMSNormAclnnDescriptor_t)desc,
size);
}
#ifdef ENABLE_ASCEND_API
GET(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
......@@ -140,16 +133,8 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
return bangRMSNorm((RMSNormBangDescriptor_t)desc, workspace, workspace_size, y, x, w, stream);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return aclnnRMSNorm((RMSNormAclnnDescriptor_t)desc,
workspace,
workspace_size,
y,
x,
w,
stream);
}
#ifdef ENABLE_ASCEND_API
CALCULATE(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
......@@ -190,10 +175,8 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
return bangDestroyRMSNormDescriptor((RMSNormBangDescriptor_t)desc);
}
#endif
#ifdef ENABLE_ASCEND_NPU
case DevAscendNpu: {
return aclnnDestroyRMSNormDescriptor((RMSNormAclnnDescriptor_t)desc);
}
#ifdef ENABLE_ASCEND_API
DESTROY(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
......
#ifndef RMS_NORM_H
#define RMS_NORM_H
#include "../../operator.h"
#include "../../tensor.h"
#include <vector>
struct RMSNormInfo {
infiniDtype_t wtype;
infiniDtype_t atype;
float epsilon;
std::vector<size_t> shape;
std::vector<ptrdiff_t> y_strides;
std::vector<ptrdiff_t> x_strides;
size_t ndim() { return shape.size(); }
size_t dim() { return shape[ndim() - 1]; }
};
inline infiniStatus_t createRMSNormInfo(RMSNormInfo *info, infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t w_desc,
float epsilon) {
auto atype = y_desc->dtype();
auto wtype = w_desc->dtype();
if (x_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (atype == INFINI_DTYPE_F16) {
if (wtype != INFINI_DTYPE_F16 && wtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
if (atype != wtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
info->wtype = wtype;
info->atype = atype;
info->epsilon = epsilon;
if (y_desc->ndim() != 2 || x_desc->ndim() != 2 || w_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch = y_desc->shape()[0];
size_t dim = y_desc->shape()[1];
if (x_desc->shape()[0] != batch || x_desc->shape()[1] != dim || w_desc->shape()[0] != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (w_desc->stride(0) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
info->shape = std::move(y_desc->shape());
info->y_strides = std::move(y_desc->strides());
info->x_strides = std::move(x_desc->strides());
return INFINI_STATUS_SUCCESS;
}
#define DESCRIPTOR(NAMESPACE) \
namespace op::rms_norm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
RMSNormInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
RMSNormInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
size_t workspaceSize() const { return _workspace_size; } \
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc, \
infiniopTensorDescriptor_t w_desc, \
float epsilon); \
infiniStatus_t calculate(void *workspace, size_t workspace_size, \
void *y, const void *x, const void *w, void *stream); \
}; \
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::rms_norm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
RMSNormInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
RMSNormInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc, \
infiniopTensorDescriptor_t w_desc, \
float epsilon); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
const void *x, \
const void *w, \
void *stream) const; \
}; \
}
#endif // RMS_NORM_H
#include "infinirt_musa.h"
#include "../../utils.h"
#include <musa_runtime.h>
#include <musa_runtime_api.h>
#define CHECK_MUSART(RT_API) CHECK_INTERNAL(RT_API, musaSuccess)
namespace infinirt::musa {
infiniStatus_t getDeviceCount(int *count) {
CHECK_MUSART(musaGetDeviceCount(count));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t setDevice(int device_id) {
CHECK_MUSART(musaSetDevice(device_id));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t deviceSynchronize() {
CHECK_MUSART(musaDeviceSynchronize());
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamCreate(infinirtStream_t *stream_ptr) {
musaStream_t stream;
CHECK_MUSART(musaStreamCreate(&stream));
*stream_ptr = stream;
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamDestroy(infinirtStream_t stream) {
CHECK_MUSART(musaStreamDestroy((musaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamSynchronize(infinirtStream_t stream) {
CHECK_MUSART(musaStreamSynchronize((musaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamWaitEvent(infinirtStream_t stream, infinirtEvent_t event) {
CHECK_MUSART(musaStreamWaitEvent((musaStream_t)stream, (musaEvent_t)event));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t eventCreate(infinirtEvent_t *event_ptr) {
musaEvent_t event;
CHECK_MUSART(musaEventCreate(&event));
*event_ptr = event;
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t eventRecord(infinirtEvent_t event, infinirtStream_t stream) {
CHECK_MUSART(musaEventRecord((musaEvent_t)event, (musaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t eventQuery(infinirtEvent_t event, infinirtEventStatus_t *status_ptr) {
auto status = musaEventQuery((musaEvent_t)event);
if (status == musaSuccess) {
*status_ptr = INFINIRT_EVENT_COMPLETE;
} else if (status == musaErrorNotReady) {
*status_ptr = INFINIRT_EVENT_NOT_READY;
} else {
CHECK_MUSART(status);
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t eventSynchronize(infinirtEvent_t event) {
CHECK_MUSART(musaEventSynchronize((musaEvent_t)event));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t eventDestroy(infinirtEvent_t event) {
CHECK_MUSART(musaEventDestroy((musaEvent_t)event));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t mallocDevice(void **p_ptr, size_t size) {
CHECK_MUSART(musaMalloc(p_ptr, size));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t mallocHost(void **p_ptr, size_t size) {
CHECK_MUSART(musaMallocHost(p_ptr, size));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t freeDevice(void *ptr) {
CHECK_MUSART(musaFree(ptr));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t freeHost(void *ptr) {
CHECK_MUSART(musaFreeHost(ptr));
return INFINI_STATUS_SUCCESS;
}
musaMemcpyKind toMusaMemcpyKind(infinirtMemcpyKind_t kind) {
switch (kind) {
case INFINIRT_MEMCPY_H2D:
return musaMemcpyHostToDevice;
case INFINIRT_MEMCPY_D2H:
return musaMemcpyDeviceToHost;
case INFINIRT_MEMCPY_D2D:
return musaMemcpyDeviceToDevice;
case INFINIRT_MEMCPY_H2H:
return musaMemcpyHostToHost;
default:
return musaMemcpyDefault;
}
}
infiniStatus_t memcpy(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind) {
CHECK_MUSART(musaMemcpy(dst, src, size, toMusaMemcpyKind(kind)));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t memcpyAsync(void *dst, const void *src, size_t size, infinirtMemcpyKind_t kind, infinirtStream_t stream) {
CHECK_MUSART(musaMemcpyAsync(dst, src, size, toMusaMemcpyKind(kind), (musaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) {
return mallocDevice(p_ptr, size);
}
infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
return freeDevice(ptr);
}
} // namespace infinirt::musa
#ifndef __INFINIRT_MUSA_H__
#define __INFINIRT_MUSA_H__
#include "../infinirt_impl.h"
namespace infinirt::musa {
#ifdef ENABLE_MOORE_API
INFINIRT_DEVICE_API_IMPL
#else
INFINIRT_DEVICE_API_NOOP
#endif
} // namespace infinirt::musa
#endif // __INFINIRT_MUSA_H__
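A minimal round trip through the wrappers above (a sketch; return values are left unchecked for brevity, and the `infinirt` types come from the public headers):

```cpp
void *dev = nullptr;
infinirtStream_t stream = nullptr;
float host_in[256] = {}, host_out[256] = {};

infinirt::musa::streamCreate(&stream);
infinirt::musa::mallocDevice(&dev, sizeof(host_in));
infinirt::musa::memcpyAsync(dev, host_in, sizeof(host_in), INFINIRT_MEMCPY_H2D, stream);
// ... enqueue kernels on `stream` ...
infinirt::musa::memcpyAsync(host_out, dev, sizeof(host_out), INFINIRT_MEMCPY_D2H, stream);
infinirt::musa::streamSynchronize(stream);
infinirt::musa::freeDevice(dev);
infinirt::musa::streamDestroy(stream);
```

Note that `mallocAsync`/`freeAsync` above fall back to the synchronous calls, so they are safe to use with any stream but do not actually defer the allocation.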
#ifndef INFINIUTILS_H
#define INFINIUTILS_H
#include "infinicore.h"
#include "utils/check.h"
#include "utils/custom_types.h"
#include "utils/rearrange.h"
......@@ -38,13 +36,13 @@ inline size_t infiniSizeOf(infiniDtype_t dtype) {
return 4;
case INFINI_DTYPE_F64:
return 8;
case INFINI_DTYPE_C8:
return 2;
case INFINI_DTYPE_C16:
return 4;
return 2;
case INFINI_DTYPE_C32:
return 8;
return 4;
case INFINI_DTYPE_C64:
return 8;
case INFINI_DTYPE_C128:
return 16;
case INFINI_DTYPE_BF16:
return 2;
......@@ -85,14 +83,14 @@ inline std::string infiniDtypeToString(infiniDtype_t dtype) {
return "F32";
case INFINI_DTYPE_F64:
return "F64";
case INFINI_DTYPE_C8:
return "C8";
case INFINI_DTYPE_C16:
return "C16";
case INFINI_DTYPE_C32:
return "C32";
case INFINI_DTYPE_C64:
return "C64";
case INFINI_DTYPE_C128:
return "C128";
case INFINI_DTYPE_BF16:
return "BF16";
default:
......
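This last hunk evidently renames the complex dtypes so the suffix is the total bit width (C16..C128) rather than the per-component width (C8..C64), changing the byte sizes to bits/8. A hypothetical test pinning the rule (not part of this commit; assumes INFINI_DTYPE_C<N> denotes a complex value of N total bits, i.e. two components of N/2 bits each):

```cpp
#include <cassert>

assert(infiniSizeOf(INFINI_DTYPE_C16) == 2);   // 2 x 8-bit components
assert(infiniSizeOf(INFINI_DTYPE_C32) == 4);   // 2 x 16-bit
assert(infiniSizeOf(INFINI_DTYPE_C64) == 8);   // 2 x 32-bit
assert(infiniSizeOf(INFINI_DTYPE_C128) == 16); // 2 x 64-bit
```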