Commit 7d60e5b8 authored by zhuyue's avatar zhuyue Committed by gongchensu
Browse files

增加cpu的add rms_norm算子,c++和python接口

parent 12cde8eb
#pragma once #pragma once
#include "ops/add.hpp" #include "ops/add.hpp"
#include "ops/add_rms_norm.hpp"
#include "ops/attention.hpp" #include "ops/attention.hpp"
#include "ops/causal_softmax.hpp" #include "ops/causal_softmax.hpp"
#include "ops/matmul.hpp" #include "ops/matmul.hpp"
......
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

// Fused Add + RMS-Norm operator: y = RMSNorm(a + b) * weight.
class AddRMSNorm {
public:
    // Backend entry-point signature: (y, a, b, weight, epsilon).
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, float);
    // Runs the operator on y's device, dispatching to the registered backend.
    static void execute(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
    // Per-device-type registry of backend implementations.
    static common::OpDispatcher<schema> &dispatcher();
};

// Out-of-place variant: allocates and returns a fresh output tensor with a's
// shape/dtype/device.
Tensor add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);

// In-place variant: writes the result into the caller-provided tensor y.
void add_rms_norm_(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);

} // namespace infinicore::op
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "infiniop/handle.h" #include "infiniop/handle.h"
#include "infiniop/ops/add.h" #include "infiniop/ops/add.h"
#include "infiniop/ops/add_rms_norm.h"
#include "infiniop/ops/attention.h" #include "infiniop/ops/attention.h"
#include "infiniop/ops/causal_softmax.h" #include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/clip.h" #include "infiniop/ops/clip.h"
......
#ifndef __INFINIOP_ADD_RMS_NORM_API_H__
#define __INFINIOP_ADD_RMS_NORM_API_H__

#include "../operator_descriptor.h"

/* Opaque handle for a fused Add + RMS-Norm operator instance:
 * y = RMSNorm(a + b) * weight. */
typedef struct InfiniopDescriptor *infiniopAddRMSNormDescriptor_t;

/* Validates the tensor descriptors and creates an operator descriptor for the
 * device owned by `handle`. */
__C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
    infiniopHandle_t handle,
    infiniopAddRMSNormDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon);

/* Queries the scratch-workspace size (in bytes) required by
 * infiniopAddRMSNorm for this descriptor. */
__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size);

/* Executes y = RMSNorm(a + b) * weight. `workspace` must be at least the size
 * reported by infiniopGetAddRMSNormWorkspaceSize; `stream` selects the
 * execution stream on stream-based devices (the CPU backend ignores it). */
__C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t desc,
                                               void *workspace,
                                               size_t workspace_size,
                                               void *y,
                                               const void *a,
                                               const void *b,
                                               const void *weight,
                                               void *stream);

/* Destroys the descriptor. Passing NULL is a no-op that returns success. */
__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc);

#endif
...@@ -40,6 +40,7 @@ from infinicore.dtype import ( ...@@ -40,6 +40,7 @@ from infinicore.dtype import (
uint8, uint8,
) )
from infinicore.ops.add import add from infinicore.ops.add import add
from infinicore.ops.add_rms_norm import add_rms_norm
from infinicore.ops.attention import attention from infinicore.ops.attention import attention
from infinicore.ops.matmul import matmul from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul from infinicore.ops.mul import mul
...@@ -102,6 +103,7 @@ __all__ = [ ...@@ -102,6 +103,7 @@ __all__ = [
"uint8", "uint8",
# Operations. # Operations.
"add", "add",
"add_rms_norm",
"attention", "attention",
"matmul", "matmul",
"mul", "mul",
......
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
    """Fused add + RMS normalization: RMSNorm(a + b) * weight.

    Args:
        a: First input tensor.
        b: Second input tensor.
        weight: Scale weights.
        epsilon: Numerical-stability constant, default 1e-5.
        out: Optional pre-allocated output tensor; written in place when given.

    Returns:
        A new Tensor when ``out`` is None, otherwise ``out``.
    """
    if out is None:
        # Out-of-place native call returns a fresh underlying tensor.
        return Tensor(_infinicore.add_rms_norm(a._underlying, b._underlying, weight._underlying, epsilon))

    # In-place native call writes into the caller-supplied output.
    _infinicore.add_rms_norm_(out._underlying, a._underlying, b._underlying, weight._underlying, epsilon)
    return out
#include "infinicore/ops/add_rms_norm.hpp"
#include "../../utils.hpp"
namespace infinicore::op {

// Lazily-constructed registry mapping device types to backend
// implementations (Meyers singleton).
common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() {
    static common::OpDispatcher<AddRMSNorm::schema> dispatcher_;
    return dispatcher_;
} // NOTE: stray ';' after the function body removed (-Wextra-semi).

// Checks that all tensors live on the same device, binds the context to y's
// device, then invokes the backend registered for that device type.
void AddRMSNorm::execute(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, a, b, weight);
    infinicore::context::setDevice(y->device());
    dispatcher().lookup(y->device().getType())(y, a, b, weight, epsilon);
}

// Out-of-place variant: allocates y with a's shape/dtype/device and delegates
// to the in-place entry point.
Tensor add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) {
    auto y = Tensor::empty(a->shape(), a->dtype(), a->device());
    add_rms_norm_(y, a, b, weight, epsilon);
    return y;
}

// In-place variant: writes RMSNorm(a + b) * weight into the provided y.
void add_rms_norm_(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) {
    AddRMSNorm::execute(y, a, b, weight, epsilon);
}

} // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/add_rms_norm.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>
namespace infinicore::op::add_rms_norm_impl::infiniop {

// Thread-local bounded cache (capacity 100) of infiniop descriptors, keyed by
// a hash of the tensor arguments and epsilon. The eviction callback destroys
// the underlying infiniop descriptor.
thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches(
    100, // capacity
    [](infiniopAddRMSNormDescriptor_t &desc) {
        if (desc != nullptr) {
            INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc));
            desc = nullptr;
        }
    });

// infiniop-backed implementation of the fused Add + RMS-Norm operator.
// Reuses a cached descriptor when one exists for the same tensor metadata and
// epsilon; otherwise creates one and caches it. A scratch workspace is
// allocated fresh on every call and released when `workspace` drops its last
// reference.
void calculate(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) {
    size_t seed = hash_combine(y, a, b, weight, epsilon);

    auto device = context::getDevice();
    auto &cache = caches.getCache(device);

    auto desc_opt = cache.get(seed);
    infiniopAddRMSNormDescriptor_t desc = nullptr;
    if (!desc_opt) {
        INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor(
            context::getInfiniopHandle(device), &desc,
            y->desc(), a->desc(), b->desc(), weight->desc(), epsilon));
        cache.put(seed, desc);
    } else {
        desc = *desc_opt;
    }

    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

    INFINICORE_CHECK_ERROR(infiniopAddRMSNorm(
        desc, workspace->data(), workspace_size,
        y->data(), a->data(), b->data(), weight->data(), context::getStream()));
}

// Self-registration at load time: installs this backend for every device type.
// NOTE(review): the `false` flag presumably means "do not override existing
// registrations" — confirm against OpDispatcher::registerAll.
static bool registered = []() {
    AddRMSNorm::dispatcher().registerAll(&calculate, false);
    return true;
}();

} // namespace infinicore::op::add_rms_norm_impl::infiniop
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include "ops/add.hpp" #include "ops/add.hpp"
#include "ops/add_rms_norm.hpp"
#include "ops/attention.hpp" #include "ops/attention.hpp"
#include "ops/causal_softmax.hpp" #include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp" #include "ops/embedding.hpp"
...@@ -22,6 +23,7 @@ namespace infinicore::ops { ...@@ -22,6 +23,7 @@ namespace infinicore::ops {
inline void bind(py::module &m) { inline void bind(py::module &m) {
bind_add(m); bind_add(m);
bind_add_rms_norm(m);
bind_attention(m); bind_attention(m);
bind_causal_softmax(m); bind_causal_softmax(m);
bind_random_sample(m); bind_random_sample(m);
......
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/add_rms_norm.hpp"
namespace py = pybind11;
namespace infinicore::ops {

// Binds the out-of-place and in-place add_rms_norm entry points onto the
// Python module `m`. The epsilon default (1e-5f) mirrors the C++ declaration.
inline void bind_add_rms_norm(py::module &m) {
    // Out-of-place: returns a freshly allocated tensor.
    m.def("add_rms_norm",
          &op::add_rms_norm,
          py::arg("a"),
          py::arg("b"),
          py::arg("weight"),
          py::arg("epsilon") = 1e-5f,
          R"doc(Fused Add and RMS Normalization.

Args:
    a: First input tensor
    b: Second input tensor
    weight: Scale weights
    epsilon: Small constant for numerical stability, default is 1e-5

Returns:
    Normalized tensor: RMSNorm(a + b) * weight
)doc");

    // In-place: writes into the caller-provided output tensor y.
    m.def("add_rms_norm_",
          &op::add_rms_norm_,
          py::arg("y"),
          py::arg("a"),
          py::arg("b"),
          py::arg("weight"),
          py::arg("epsilon") = 1e-5f,
          R"doc(In-place Fused Add and RMS Normalization.

Args:
    y: Output tensor
    a: First input tensor
    b: Second input tensor
    weight: Scale weights
    epsilon: Small constant for numerical stability, default is 1e-5
)doc");
}

} // namespace infinicore::ops
#ifndef ADD_RMS_NORM_H
#define ADD_RMS_NORM_H

#include "../../operator.h"
#include "info.h"

// Declares the per-backend Descriptor class for the fused Add + RMS-Norm
// operator inside namespace op::add_rms_norm::<NAMESPACE>. Each backend
// (e.g. cpu) instantiates this macro in its header and defines ~Descriptor,
// create() and calculate() in its own translation unit. The Opaque struct
// holds backend-private state (the CPU backend leaves it null).
//
// NOTE: no comments inside the macro body — a '//' there would swallow the
// line continuations.
#define DESCRIPTOR(NAMESPACE) \
    \
    namespace op::add_rms_norm::NAMESPACE { \
    class Descriptor final : public InfiniopDescriptor { \
        struct Opaque; \
        Opaque *_opaque; \
        AddRMSNormInfo _info; \
        size_t _workspace_size; \
        \
        Descriptor( \
            Opaque *opaque, \
            AddRMSNormInfo info, \
            size_t workspace_size, \
            infiniDevice_t device_type, \
            int device_id) \
            : InfiniopDescriptor{device_type, device_id}, \
              _opaque(opaque), \
              _info(info), \
              _workspace_size(workspace_size) {} \
        \
    public: \
        ~Descriptor(); \
        \
        size_t workspaceSize() const { return _workspace_size; } \
        \
        static infiniStatus_t create( \
            infiniopHandle_t handle, \
            Descriptor **desc_ptr, \
            infiniopTensorDescriptor_t y_desc, \
            infiniopTensorDescriptor_t a_desc, \
            infiniopTensorDescriptor_t b_desc, \
            infiniopTensorDescriptor_t weight_desc, \
            float epsilon); \
        \
        infiniStatus_t calculate( \
            void *workspace, size_t workspace_size, \
            void *y, \
            const void *a, \
            const void *b, \
            const void *weight, \
            void *stream) const; \
    }; \
    }

#endif // ADD_RMS_NORM_H
#include "add_rms_norm_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../reduce/cpu/reduce.h"
namespace op::add_rms_norm::cpu {
// The CPU descriptor holds no opaque backend state, so nothing to release.
Descriptor::~Descriptor() {}

// Validates the tensor descriptors through AddRMSNormInfo::create and, on
// success, publishes a CPU descriptor. The opaque pointer is null and the
// workspace size is 0 because the CPU kernels need no auxiliary state or
// scratch memory.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon) {
    auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon);
    CHECK_RESULT(result);
    *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Fused Add + RMS-Norm kernel for full-precision types (float/double):
//     y = (a + b) * w / sqrt(mean((a + b)^2) + epsilon)
// per (batch, head) row, accumulating in T itself.
//
// Improvement over the original: the element-wise sum (a + b) was computed
// twice per element (once for the sum of squares, once for the store). It is
// now staged in the output row on the first pass and scaled in place on the
// second, halving the loads/adds. Results are bitwise identical, and exact
// aliasing of y with a or b (the in-place call pattern) remains correct:
// y[k] is written only after a[k]/b[k] have been read for that k.
template <typename T>
infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const T *w) {
    const size_t batch_size = info->shape[0];
    // Rows are (batch, head); 2-D inputs have a single implicit head.
    const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
    const size_t dim = info->dim();
    const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);

#pragma omp parallel for
    for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) {
        const size_t i = block_idx / nhead; // batch index
        const size_t j = block_idx % nhead; // head index

        const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
        const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
        T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];

        // Pass 1: stage the element-wise sums in y and accumulate squares.
        T sum_squared = (T)0;
        for (size_t k = 0; k < dim; k++) {
            const T sum_val = a_ptr[k] + b_ptr[k];
            y_ptr[k] = sum_val;
            sum_squared += sum_val * sum_val;
        }

        // rms = 1 / sqrt(mean(sum^2) + eps)
        const T rms = (T)1 / std::sqrt(sum_squared / (T)(dim) + (T)(info->epsilon));

        // Pass 2: scale the staged sums by the weight and the RMS factor.
        for (size_t k = 0; k < dim; k++) {
            y_ptr[k] = y_ptr[k] * w[k] * rms;
        }
    }

    return INFINI_STATUS_SUCCESS;
}
// Fused Add + RMS-Norm kernel for half-precision activations (fp16/bf16).
// All arithmetic is performed in fp32; only the final store casts back to T.
// The element-wise sum a+b is deliberately recomputed in the second loop so
// the fp32 intermediate is never rounded through the half-precision type.
template <typename T, typename Tw>
infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const Tw *w) {
    static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value,
                  "T must be fp16_t or bf16_t");

    const size_t batch_size = info->shape[0];
    // Rows are (batch, head); 2-D inputs have a single implicit head.
    const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
    const size_t dim = info->dim();
    const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);

#pragma omp parallel for
    for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) {
        const size_t i = block_idx / nhead; // batch index
        const size_t j = block_idx % nhead; // head index

        const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
        const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
        T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];

        // Compute sum of squares for RMS normalization
        float sum_squared = 0.0f;
        for (size_t k = 0; k < dim; k++) {
            float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
            sum_squared += sum_val * sum_val;
        }

        // Compute RMS: 1 / (sqrt(sum/dim + eps))
        float rms = 1.f / std::sqrt(sum_squared / (float)(dim) + info->epsilon);

        // Apply normalization: y = (a + b) * w * rms
        for (size_t k = 0; k < dim; k++) {
            float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
            float val;
            if constexpr (std::is_same<Tw, float>::value) {
                val = sum_val * w[k] * rms;
            } else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
                // Half-precision weights are widened to fp32 before scaling.
                val = sum_val * utils::cast<float>(w[k]) * rms;
            } else {
                // Unreachable for the dtype combinations admitted by
                // AddRMSNormInfo::create.
                std::abort();
            }
            y_ptr[k] = utils::cast<T>(val);
        }
    }

    return INFINI_STATUS_SUCCESS;
}
// Runs the CPU kernels, dispatching on the activation dtype and then on the
// weight dtype. F32/F64 require matching weights (enforced by
// AddRMSNormInfo::create); half-precision activations accept F16/BF16/F32
// weights. `workspace` and `stream` are unused on the CPU backend.
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y, const void *a, const void *b, const void *weight,
    void *stream) const {
    switch (_info.atype) {
    case INFINI_DTYPE_F16:
        switch (_info.wtype) {
        case INFINI_DTYPE_F16:
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight));
            break;
        case INFINI_DTYPE_F32:
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight));
            break;
        case INFINI_DTYPE_BF16:
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight));
            break;
        default:
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        break;
    case INFINI_DTYPE_BF16:
        switch (_info.wtype) {
        case INFINI_DTYPE_BF16:
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight));
            break;
        case INFINI_DTYPE_F32:
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight));
            break;
        case INFINI_DTYPE_F16:
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight));
            break;
        default:
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        break;
    case INFINI_DTYPE_F32:
        CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (const float *)a, (const float *)b, (const float *)weight));
        break;
    case INFINI_DTYPE_F64:
        CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (const double *)a, (const double *)b, (const double *)weight));
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::add_rms_norm::cpu
#ifndef __ADD_RMS_NORM_CPU_H__
#define __ADD_RMS_NORM_CPU_H__
#include "../add_rms_norm.h"
DESCRIPTOR(cpu)
#endif
#ifndef __ADD_RMS_NORM_INFO_H__
#define __ADD_RMS_NORM_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::add_rms_norm {
// Validated metadata for the fused Add + RMS-Norm operator: dtypes, epsilon,
// the logical shape and the per-tensor strides. Instances are built
// exclusively through create(), which performs all validation.
class AddRMSNormInfo {
    AddRMSNormInfo() = default;

public:
    infiniDtype_t wtype;               // weight dtype
    infiniDtype_t atype;               // activation dtype (shared by y/a/b)
    float epsilon;                     // added to the mean of squares
    std::vector<size_t> shape;         // y's shape (2-D or 3-D)
    std::vector<ptrdiff_t> y_strides;
    std::vector<ptrdiff_t> a_strides;
    std::vector<ptrdiff_t> b_strides;

    size_t ndim() const { return shape.size(); }
    // Size of the normalized (innermost) dimension.
    size_t dim() const { return shape[ndim() - 1]; }

    // Validates dtypes, ranks, shapes and innermost-dimension contiguity.
    // Accepted layouts: y/a/b of identical shape, 2-D (batch, dim) or
    // 3-D (batch, nhead, dim); weight is 1-D of length dim.
    static utils::Result<AddRMSNormInfo> create(
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t a_desc,
        infiniopTensorDescriptor_t b_desc,
        infiniopTensorDescriptor_t weight_desc,
        float epsilon) {
        auto atype = y_desc->dtype();
        auto wtype = weight_desc->dtype();

        // Check that all input tensors have the same dtype
        if (a_desc->dtype() != atype || b_desc->dtype() != atype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        if (atype == INFINI_DTYPE_F16 || atype == INFINI_DTYPE_BF16) {
            // For half-precision types (FP16/BF16), weights can be the same half-precision type or FP32
            if (wtype != atype && wtype != INFINI_DTYPE_F32 && wtype != INFINI_DTYPE_BF16 && wtype != INFINI_DTYPE_F16) {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
        } else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
            // For FP32/FP64, activations and weights must be of the same type
            if (atype != wtype) {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        const size_t y_ndim = y_desc->ndim();
        const size_t a_ndim = a_desc->ndim();
        const size_t b_ndim = b_desc->ndim();
        const size_t w_ndim = weight_desc->ndim();
        // y/a/b must share rank; weight must be 1-D.
        if (y_ndim != a_ndim || y_ndim != b_ndim || w_ndim != 1) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        size_t batch = 1;
        size_t nhead = 1;
        size_t dim = 0;
        if (y_ndim == 2) {
            // (batch, dim): every tensor dimension must match y's.
            batch = y_desc->dim(0);
            dim = y_desc->dim(1);
            if (a_desc->dim(0) != batch || a_desc->dim(1) != dim ||
                b_desc->dim(0) != batch || b_desc->dim(1) != dim ||
                weight_desc->dim(0) != dim) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
        } else if (y_ndim == 3) {
            // (batch, nhead, dim): every tensor dimension must match y's.
            batch = y_desc->dim(0);
            nhead = y_desc->dim(1);
            dim = y_desc->dim(2);
            if (a_desc->dim(0) != batch || a_desc->dim(1) != nhead || a_desc->dim(2) != dim ||
                b_desc->dim(0) != batch || b_desc->dim(1) != nhead || b_desc->dim(2) != dim ||
                weight_desc->dim(0) != dim) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
        } else {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // Check contiguity of the last dimension (the CPU kernels index the
        // normalized dimension with a plain k, assuming stride 1).
        if (y_desc->stride(y_ndim - 1) != 1 ||
            a_desc->stride(a_ndim - 1) != 1 ||
            b_desc->stride(b_ndim - 1) != 1 ||
            weight_desc->stride(w_ndim - 1) != 1) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }

        AddRMSNormInfo info;
        info.wtype = wtype;
        info.atype = atype;
        info.epsilon = epsilon;
        info.shape = y_desc->shape();
        info.y_strides = y_desc->strides();
        info.a_strides = a_desc->strides();
        info.b_strides = b_desc->strides();
        return utils::Result<AddRMSNormInfo>(info);
    }
};
} // namespace op::add_rms_norm
#endif // __ADD_RMS_NORM_INFO_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/add_rms_norm.h"
#ifdef ENABLE_CPU_API
#include "cpu/add_rms_norm_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
// TODO: Add NVIDIA implementation
// #include "nvidia/add_rms_norm_nvidia.cuh"
#endif
#ifdef ENABLE_ASCEND_API
// TODO: Add Ascend implementation
// #include "ascend/add_rms_norm_aclnn.h"
#endif
#ifdef ENABLE_CAMBRICON_API
// TODO: Add Cambricon implementation
// #include "bang/add_rms_norm_bang.h"
#endif
#ifdef ENABLE_METAX_API
// TODO: Add Metax implementation
// #include "metax/add_rms_norm_metax.cuh"
#endif
#ifdef ENABLE_MOORE_API
// TODO: Add Moore implementation
// #include "moore/add_rms_norm_moore.h"
#endif
#ifdef ENABLE_KUNLUN_API
// TODO: Add Kunlun implementation
// #include "kunlun/add_rms_norm_kunlun.h"
#endif
// Creates an Add + RMS-Norm descriptor by dispatching on the handle's device
// type. Only the CPU backend is implemented so far; the commented-out entries
// are placeholders for future backends.
__C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
    infiniopHandle_t handle,
    infiniopAddRMSNormDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon) {

#define CREATE(CASE, NAMESPACE) \
    case CASE: \
        return op::add_rms_norm::NAMESPACE::Descriptor::create( \
            handle, \
            reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor **>(desc_ptr), \
            y_desc, \
            a_desc, \
            b_desc, \
            weight_desc, \
            epsilon)

    switch (handle->device) {
#ifdef ENABLE_CPU_API
        CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        // CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        // CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        // CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        // CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
        // CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}
// Reports the scratch-workspace size required by infiniopAddRMSNorm for the
// given descriptor (0 on the CPU backend). Dispatches on the descriptor's
// device type.
__C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size) {

#define GET(CASE, NAMESPACE) \
    case CASE: \
        *size = reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        // GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        // GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        // GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        // GET(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
        // GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET
    // The unreachable trailing `return` that followed the switch was removed:
    // every path above already returns, matching the sibling
    // Create/Calculate/Destroy entry points.
}
// Executes y = RMSNorm(a + b) * weight using the backend selected when the
// descriptor was created. `workspace` must be at least the size reported by
// infiniopGetAddRMSNormWorkspaceSize.
__C infiniStatus_t infiniopAddRMSNorm(
    infiniopAddRMSNormDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *a,
    const void *b,
    const void *weight,
    void *stream) {

#define CALCULATE(CASE, NAMESPACE) \
    case CASE: \
        return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc) \
            ->calculate(workspace, workspace_size, y, a, b, weight, stream)

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        // CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        // CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        // CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
        // CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}
// Destroys a descriptor created by infiniopCreateAddRMSNormDescriptor.
// A null descriptor is accepted and treated as a successful no-op.
__C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc) {
    if (desc == nullptr) {
        return INFINI_STATUS_SUCCESS;
    }

#define DESTROY(CASE, NAMESPACE) \
    case CASE: \
        delete reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        DESTROY(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        // DESTROY(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        // DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        // DESTROY(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        // DESTROY(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
        // DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DESTROY
}
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (y_shape, a_shape, b_shape, w_shape, y_strides, a_strides, b_strides)
_TEST_CASES_DATA = [
# Basic cases
((1, 4), (1, 4), (1, 4), (4,), None, None, None),
((2, 4), (2, 4), (2, 4), (4,), None, None, None),
((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), None, None, None),
# Strided cases
((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1), (12, 8, 1)),
# Large tensors
((16, 2048), (16, 2048), (16, 2048), (2048,), None, None, None),
((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)),
]
# Tolerance configuration
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 2e-3, "rtol": 2e-3},
infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
infinicore.float32: {"atol": 1e-5, "rtol": 1e-4},
}
# Data types for individual tensors
_INPUT_DTYPES = [infinicore.float16, infinicore.bfloat16]
_WEIGHT_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# EPSILON constant for AddRMSNorm
_EPSILON = 1e-5
def parse_test_cases():
    """
    Parse AddRMSNorm test case data and return list of TestCase objects.
    Format: (y_shape, a_shape, b_shape, w_shape, y_strides, a_strides, b_strides)

    For every shape/stride entry and every (input dtype, weight dtype)
    combination this emits up to four cases: out-of-place, explicit ``out=``
    tensor, in-place on ``a``, and in-place on ``b``.
    """
    test_cases = []

    for data in _TEST_CASES_DATA:
        y_shape = data[0]  # Output shape
        a_shape = data[1]  # First input shape
        b_shape = data[2]  # Second input shape
        w_shape = data[3]  # Weight shape (1D)
        y_strides = data[4] if len(data) > 4 else None
        a_strides = data[5] if len(data) > 5 else None
        b_strides = data[6] if len(data) > 6 else None

        # Broadcast (stride-0) tensors cannot be written in place.
        a_supports_inplace = not is_broadcast(a_strides)
        b_supports_inplace = not is_broadcast(b_strides)
        y_supports_inplace = not is_broadcast(y_strides)

        # Generate test cases for all dtype combinations
        for input_dtype in _INPUT_DTYPES:
            for weight_dtype in _WEIGHT_DTYPES:
                # Use input dtype tolerance for output
                tolerance = _TOLERANCE_MAP.get(
                    input_dtype, {"atol": 1e-5, "rtol": 1e-4}
                )

                # Create typed tensor specs
                a_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
                b_spec = TensorSpec.from_tensor(b_shape, b_strides, input_dtype)
                w_spec = TensorSpec.from_tensor(
                    w_shape, None, weight_dtype
                )  # Weight is always contiguous
                y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype)

                # Test Case 1: Out-of-place (return value)
                # NOTE: description strings were f-strings without
                # placeholders (lint F541); plain literals now.
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec, w_spec],
                        kwargs={"epsilon": _EPSILON},
                        output_spec=None,
                        comparison_target=None,
                        tolerance=tolerance,
                        description="AddRMSNorm - OUT_OF_PLACE",
                    )
                )

                # Test Case 2: In-place with explicit output tensor (add_rms_norm(a, b, w, out=y))
                if y_supports_inplace:
                    test_cases.append(
                        TestCase(
                            inputs=[a_spec, b_spec, w_spec],
                            kwargs={"epsilon": _EPSILON},
                            output_spec=y_spec,  # Specify the output tensor spec
                            comparison_target="out",
                            tolerance=tolerance,
                            description="AddRMSNorm - INPLACE(out)",
                        )
                    )

                # Test Case 3: In-place on first input (add_rms_norm(a, b, w, out=a))
                if a_supports_inplace:
                    test_cases.append(
                        TestCase(
                            inputs=[a_spec, b_spec, w_spec],
                            kwargs={
                                "out": 0,
                                "epsilon": _EPSILON,
                            },  # Use index 0 for first input
                            output_spec=None,
                            comparison_target=0,  # Compare first input
                            tolerance=tolerance,
                            description="AddRMSNorm - INPLACE(a)",
                        )
                    )

                # Test Case 4: In-place on second input (add_rms_norm(a, b, w, out=b))
                if b_supports_inplace:
                    test_cases.append(
                        TestCase(
                            inputs=[a_spec, b_spec, w_spec],
                            kwargs={
                                "out": 1,
                                "epsilon": _EPSILON,
                            },  # Use index 1 for second input
                            output_spec=None,
                            comparison_target=1,  # Compare second input
                            tolerance=tolerance,
                            description="AddRMSNorm - INPLACE(b)",
                        )
                    )

    return test_cases
class OpTest(BaseOperatorTest):
    """Test driver for the fused AddRMSNorm operator."""

    def __init__(self):
        super().__init__("AddRMSNorm")

    def get_test_cases(self):
        # All cases are generated from the module-level tables above.
        return parse_test_cases()

    def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
        """Reference: RMSNorm(a + b) * weight, computed in fp32."""
        original_dtype = a.dtype
        # Widen before adding so the normalization happens in fp32.
        fused = a.to(torch.float32) + b.to(torch.float32)
        w32 = weight.to(torch.float32)
        mean_sq = fused.pow(2).mean(-1, keepdim=True)
        normalized = (fused * torch.rsqrt(mean_sq + epsilon) * w32).to(original_dtype)
        if out is None:
            return normalized
        out.copy_(normalized)
        return out

    def infinicore_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
        """Run the InfiniCore operator under test."""
        return infinicore.add_rms_norm(a, b, weight, epsilon, out=out)
def main():
    """Main entry point: run every AddRMSNorm test case and exit with status."""
    runner = GenericTestRunner(OpTest)
    runner.run_and_exit()


if __name__ == "__main__":
    main()
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
TestWorkspace,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
)
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# y_shape, a_shape, b_shape, w_shape, y_stride, a_stride, b_stride
((1, 4), (1, 4), (1, 4), (4,), None, None, None),
((2, 4), (2, 4), (2, 4), (4,), None, None, None),
((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), None, None, None),
((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1), (12, 8, 1)),
((16, 2048), (16, 2048), (16, 2048), (2048,), None, None, None),
((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((15, 8192), (15, 8192), (15, 8192), (8192,), None, None, None),
]
# w (weight) types
# Note: 'None' means the same as input dtype
_WEIGHT_DTYPES = [None, InfiniDtype.F32, InfiniDtype.F16, InfiniDtype.BF16]
# a, b types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16]
# Form the test cases by appending each element of _WEIGHT_DTYPES to each tuple in _TEST_CASES_
_TEST_CASES = [
test_case + (w_dtype,) for test_case in _TEST_CASES_ for w_dtype in _WEIGHT_DTYPES
]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 2e-3, "rtol": 2e-3},
InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def add_rms_norm(ans, a, b, w, eps):
    """Torch reference for fused add + RMSNorm, written into ``ans``.

    Computes RMSNorm(a + b) * w in fp32 and stores the result, cast back to
    the input dtype, into ``ans`` via ``Tensor.set_``.
    """
    out_dtype = a.dtype
    # All intermediate math in fp32 regardless of the input precision.
    fused = a.to(torch.float32) + b.to(torch.float32)
    inv_rms = (fused.pow(2).mean(-1, keepdim=True) + eps).rsqrt()
    normalized = fused * inv_rms * w.to(torch.float32)
    ans.set_(normalized.to(out_dtype))
def test(
    handle,
    device,
    y_shape,
    a_shape,
    b_shape,
    w_shape,
    y_stride,
    a_stride,
    b_stride,
    w_dtype=InfiniDtype.F32,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Run one AddRMSNorm case and compare the library result against torch."""
    # A falsy w_dtype (the None entry in _WEIGHT_DTYPES) means "same as input".
    w_dtype = w_dtype if w_dtype else dtype
    print(
        f"Testing AddRMSNorm on {InfiniDeviceNames[device]} with y_shape:{y_shape} a_shape:{a_shape} b_shape:{b_shape} w_shape:{w_shape}"
        f" y_stride:{y_stride} a_stride:{a_stride} b_stride:{b_stride} w_dtype:{InfiniDtypeNames[w_dtype]} dtype:{InfiniDtypeNames[dtype]}"
    )

    y = TestTensor(y_shape, y_stride, dtype, device, mode="ones")
    a = TestTensor(a_shape, a_stride, dtype, device, scale=0.01)
    b = TestTensor(b_shape, b_stride, dtype, device, scale=0.01)
    w = TestTensor(w_shape, None, w_dtype, device)

    eps = 1e-6
    # Compute the torch reference into y's torch-side mirror.
    add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps)

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateAddRMSNormDescriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            a.descriptor,
            b.descriptor,
            w.descriptor,
            eps,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [a, b, y, w]:
        tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetAddRMSNormWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, y.device)

    def lib_add_rms_norm():
        # Execute the library operator; stream is None (default stream).
        check_error(
            LIBINFINIOP.infiniopAddRMSNorm(
                descriptor,
                workspace.data(),
                workspace_size.value,
                y.data(),
                a.data(),
                b.data(),
                w.data(),
                None,
            )
        )

    lib_add_rms_norm()

    # Compare library output against the torch reference within tolerance.
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation(" lib", lambda: lib_add_rms_norm(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(LIBINFINIOP.infiniopDestroyAddRMSNormDescriptor(descriptor))
if __name__ == "__main__":
    args = get_args()

    # Configure testing options (rebinds the module-level flags read by test()).
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")
...@@ -383,6 +383,43 @@ def rms_norm_(lib): ...@@ -383,6 +383,43 @@ def rms_norm_(lib):
] ]
@OpRegister.operator
def add_rms_norm_(lib):
    """Declare ctypes signatures for the AddRMSNorm entry points of `lib`."""
    lib.infiniopCreateAddRMSNormDescriptor.restype = c_int32
    lib.infiniopCreateAddRMSNormDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,  # y
        infiniopTensorDescriptor_t,  # a
        infiniopTensorDescriptor_t,  # b
        infiniopTensorDescriptor_t,  # weight
        c_float,  # epsilon
    ]

    lib.infiniopGetAddRMSNormWorkspaceSize.restype = c_int32
    lib.infiniopGetAddRMSNormWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopAddRMSNorm.restype = c_int32
    lib.infiniopAddRMSNorm.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,  # workspace
        c_size_t,  # workspace_size
        c_void_p,  # y
        c_void_p,  # a
        c_void_p,  # b
        c_void_p,  # weight
        c_void_p,  # stream
    ]

    lib.infiniopDestroyAddRMSNormDescriptor.restype = c_int32
    lib.infiniopDestroyAddRMSNormDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]
@OpRegister.operator @OpRegister.operator
def rope_(lib): def rope_(lib):
lib.infiniopCreateRoPEDescriptor.restype = c_int32 lib.infiniopCreateRoPEDescriptor.restype = c_int32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment