Unverified commit 3b5afffe authored by Haojie Wang, committed by GitHub

Merge pull request #842 from gongchensu/Issue/791

Issue/791: add the add_rms_norm fused operator
parents 2d9d5c30 7712471f
#pragma once
#include "ops/add.hpp"
#include "ops/add_rms_norm.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/matmul.hpp"
......
#pragma once
#include "../device.hpp"
#include "common/op.hpp"
#include <utility>
namespace infinicore::op {
class AddRMSNorm {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float);
static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
static common::OpDispatcher<schema> &dispatcher();
};
// Fused Add and RMS Normalization
// Returns: (normalized_result, add_result)
// The add_result can be used as residual for subsequent layers
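// Per row: residual_out = a + b; y = (a + b) * weight / sqrt(mean((a + b)^2) + epsilon)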
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
} // namespace infinicore::op
@@ -3,6 +3,7 @@
#include "infiniop/handle.h"
#include "infiniop/ops/add.h"
#include "infiniop/ops/add_rms_norm.h"
#include "infiniop/ops/attention.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/clip.h"
......
#ifndef __INFINIOP_ADD_RMS_NORM_API_H__
#define __INFINIOP_ADD_RMS_NORM_API_H__
#include "../operator_descriptor.h"
typedef struct InfiniopDescriptor *infiniopAddRMSNormDescriptor_t;
__C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopHandle_t handle,
infiniopAddRMSNormDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon,
infiniopTensorDescriptor_t residual_out_desc);
__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *a,
const void *b,
const void *weight,
void *residual_out,
void *stream);
__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc);
#endif
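For C-API consumers, the intended call order is: create descriptor, query workspace size, execute, destroy descriptor. Below is a minimal, hypothetical ctypes sketch of that sequence (it assumes `lib` is the loaded infiniop shared library, that the handle, tensor descriptors, and device buffers were created elsewhere, and that INFINI_STATUS_SUCCESS == 0); the ctypes test at the end of this PR exercises the same flow:

import ctypes
from ctypes import byref, c_uint64

def run_add_rms_norm(lib, handle, y_desc, a_desc, b_desc, w_desc, residual_desc,
                     y, a, b, w, residual, epsilon=1e-5, stream=None):
    # Opaque stand-in for infiniopAddRMSNormDescriptor_t.
    desc = ctypes.c_void_p()
    # 1. Create the descriptor (validates shapes, strides, and dtypes).
    assert lib.infiniopCreateAddRMSNormDescriptor(
        handle, byref(desc), y_desc, a_desc, b_desc, w_desc,
        ctypes.c_float(epsilon), residual_desc) == 0
    # 2. Query and allocate the workspace (host buffer shown; a device
    #    allocation would be needed for GPU backends).
    size = c_uint64(0)
    assert lib.infiniopGetAddRMSNormWorkspaceSize(desc, byref(size)) == 0
    workspace = ctypes.create_string_buffer(size.value) if size.value else None
    # 3. Run the fused kernel; writes y and residual_out in a single pass.
    assert lib.infiniopAddRMSNorm(desc, workspace, size.value,
                                  y, a, b, w, residual, stream) == 0
    # 4. Release the descriptor.
    assert lib.infiniopDestroyAddRMSNormDescriptor(desc) == 0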
@@ -40,6 +40,7 @@ from infinicore.dtype import (
    uint8,
)
from infinicore.ops.add import add
from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_
from infinicore.ops.attention import attention
from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul
@@ -105,6 +106,8 @@ __all__ = [
    "uint8",
    # Operations.
    "add",
    "add_rms_norm",
    "add_rms_norm_",
    "attention",
    "matmul",
    "mul",
......
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
"""
Fused Add and RMS Normalization.
Args:
a: First input tensor
b: Second input tensor
weight: Scale weights
epsilon: Small constant for numerical stability, default is 1e-5
out: Optional output tuple (y, residual_out) for in-place operation
Returns:
Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b)
The add_result can be used as residual for subsequent layers.
"""
if out is None:
result = _infinicore.add_rms_norm(
a._underlying, b._underlying, weight._underlying, epsilon
)
return (Tensor(result[0]), Tensor(result[1]))
y, residual_out = out
_infinicore.add_rms_norm_(
y._underlying,
residual_out._underlying,
a._underlying,
b._underlying,
weight._underlying,
epsilon,
)
return (y, residual_out)
def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5):
"""In-place Fused Add and RMS Normalization."""
_infinicore.add_rms_norm_(
y._underlying,
residual_out._underlying,
a._underlying,
b._underlying,
weight._underlying,
epsilon,
)
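A brief usage sketch of the two entry points above, wrapped in a hypothetical helper (assumes the input tensors live on the same device):

import infinicore

def fused_residual_norm(x, res, w, out=None):
    """Fuse the residual add and RMSNorm of a transformer layer."""
    if out is None:
        # Out-of-place: one fused launch returns (RMSNorm(x + res) * w, x + res).
        return infinicore.add_rms_norm(x, res, w, epsilon=1e-5)
    # In-place: write into preallocated (y, residual_out) tensors.
    y, residual_out = out
    infinicore.add_rms_norm_(y, residual_out, x, res, w, epsilon=1e-5)
    return (y, residual_out)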
#include "infinicore/ops/add_rms_norm.hpp"
#include "../../utils.hpp"
namespace infinicore::op {
common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() {
static common::OpDispatcher<AddRMSNorm::schema> dispatcher_;
return dispatcher_;
}
void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight);
infinicore::context::setDevice(y->device());
dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon);
}
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) {
auto y = Tensor::empty(a->shape(), a->dtype(), a->device());
auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device());
add_rms_norm_(y, residual_out, a, b, weight, epsilon);
return std::make_pair(y, residual_out);
}
void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon);
}
} // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/add_rms_norm.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>
namespace infinicore::op::add_rms_norm_impl::infiniop {
thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches(
100, // capacity
[](infiniopAddRMSNormDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc));
desc = nullptr;
}
});
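// Descriptors are cached per thread and per device, keyed by a combined hash
// of the operand tensors and epsilon, so repeated calls with an identical
// configuration reuse the compiled descriptor instead of re-creating it.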
void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
auto desc_opt = cache.get(seed);
infiniopAddRMSNormDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor(
context::getInfiniopHandle(device), &desc,
y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}
size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(infiniopAddRMSNorm(
desc, workspace->data(), workspace_size,
y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream()));
}
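// Self-registration: this static initializer hooks the infiniop-backed
// implementation into the AddRMSNorm dispatcher at library load time.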
static bool registered = []() {
AddRMSNorm::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::add_rms_norm_impl::infiniop
@@ -3,6 +3,7 @@
#include <pybind11/pybind11.h>
#include "ops/add.hpp"
#include "ops/add_rms_norm.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp"
@@ -24,6 +25,7 @@ namespace infinicore::ops {
inline void bind(py::module &m) {
    bind_add(m);
    bind_add_rms_norm(m);
    bind_attention(m);
    bind_causal_softmax(m);
    bind_random_sample(m);
......
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/add_rms_norm.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_add_rms_norm(py::module &m) {
m.def("add_rms_norm",
&op::add_rms_norm,
py::arg("a"),
py::arg("b"),
py::arg("weight"),
py::arg("epsilon") = 1e-5f,
R"doc(Fused Add and RMS Normalization.
Args:
a: First input tensor
b: Second input tensor
weight: Scale weights
epsilon: Small constant for numerical stability, default is 1e-5
Returns:
Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b)
The add_result can be used as residual for subsequent layers.
)doc");
m.def("add_rms_norm_",
&op::add_rms_norm_,
py::arg("y"),
py::arg("residual_out"),
py::arg("a"),
py::arg("b"),
py::arg("weight"),
py::arg("epsilon") = 1e-5f,
R"doc(In-place Fused Add and RMS Normalization.
Args:
y: Output tensor for normalized result
residual_out: Output tensor for add result (a + b) before normalization
a: First input tensor
b: Second input tensor
weight: Scale weights
epsilon: Small constant for numerical stability, default is 1e-5
)doc");
}
} // namespace infinicore::ops
#ifndef ADD_RMS_NORM_H
#define ADD_RMS_NORM_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::add_rms_norm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
AddRMSNormInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
AddRMSNormInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t weight_desc, \
float epsilon, \
infiniopTensorDescriptor_t residual_out_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
const void *a, \
const void *b, \
const void *weight, \
void *residual_out, \
void *stream) const; \
}; \
}
#endif // ADD_RMS_NORM_H
#include "add_rms_norm_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../reduce/cpu/reduce.h"
namespace op::add_rms_norm::cpu {
Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon,
infiniopTensorDescriptor_t residual_out_desc) {
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename T>
infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const T *w, T *residual_out) {
const size_t batch_size = info->shape[0];
const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
const size_t dim = info->dim();
const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);
#pragma omp parallel for
for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) {
const size_t i = block_idx / nhead; // batch index
const size_t j = block_idx % nhead; // head index
const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1];
// Compute add(a, b) once and store it
T sum_squared = (T)0;
for (size_t k = 0; k < dim; k++) {
T sum_val = a_ptr[k] + b_ptr[k];
residual_out_ptr[k] = sum_val; // Store add result
sum_squared += sum_val * sum_val;
}
// Compute inverse RMS: 1 / sqrt(mean(sum^2) + eps), where mean = sum_squared / dim
T rms = (T)1 / std::sqrt(sum_squared / (T)(dim) + (T)(info->epsilon));
// Apply normalization: y = (a + b) * w * rms
// Reuse stored values from residual_out
for (size_t k = 0; k < dim; k++) {
y_ptr[k] = residual_out_ptr[k] * w[k] * rms;
}
}
return INFINI_STATUS_SUCCESS;
}
template <typename T, typename Tw>
infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const Tw *w, T *residual_out) {
static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value,
"T must be fp16_t or bf16_t");
const size_t batch_size = info->shape[0];
const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
const size_t dim = info->dim();
const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);
#pragma omp parallel for
for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) {
const size_t i = block_idx / nhead; // batch index
const size_t j = block_idx % nhead; // head index
const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1];
// Compute sum of squares for RMS normalization and store add result
float sum_squared = 0.0f;
for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
residual_out_ptr[k] = utils::cast<T>(sum_val); // Store add result
sum_squared += sum_val * sum_val;
}
// Compute inverse RMS: 1 / sqrt(sum_squared / dim + eps)
float rms = 1.f / std::sqrt(sum_squared / (float)(dim) + info->epsilon);
// Apply normalization: y = (a + b) * w * rms
// Reuse stored values from residual_out
for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(residual_out_ptr[k]);
float val;
if constexpr (std::is_same<Tw, float>::value) {
val = sum_val * w[k] * rms;
} else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
val = sum_val * utils::cast<float>(w[k]) * rms;
} else {
std::abort();
}
y_ptr[k] = utils::cast<T>(val);
}
}
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *a, const void *b, const void *weight,
void *residual_out, void *stream) const {
if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight, (fp16_t *)residual_out));
} else if (_info.wtype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight, (fp16_t *)residual_out));
} else if (_info.wtype == INFINI_DTYPE_BF16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight, (fp16_t *)residual_out));
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_info.atype == INFINI_DTYPE_BF16) {
if (_info.wtype == INFINI_DTYPE_BF16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight, (bf16_t *)residual_out));
} else if (_info.wtype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight, (bf16_t *)residual_out));
} else if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight, (bf16_t *)residual_out));
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_info.atype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (const float *)a, (const float *)b, (const float *)weight, (float *)residual_out));
} else if (_info.atype == INFINI_DTYPE_F64) {
CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (const double *)a, (const double *)b, (const double *)weight, (double *)residual_out));
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::add_rms_norm::cpu
#ifndef __ADD_RMS_NORM_CPU_H__
#define __ADD_RMS_NORM_CPU_H__
#include "../add_rms_norm.h"
DESCRIPTOR(cpu)
#endif
#ifndef __ADD_RMS_NORM_CUDA_KERNEL_H__
#define __ADD_RMS_NORM_CUDA_KERNEL_H__
#include <cub/block/block_reduce.cuh>
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
__device__ void add_rmsnormBlock(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
size_t dim,
float epsilon) {
// Each block handles one head of one batch row
// Each thread processes every BLOCK_SIZE-th element of the row
size_t batch_idx = blockIdx.x / nhead;
size_t head_idx = blockIdx.x % nhead;
auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
auto a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
auto b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
auto w_ptr = w;
Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
// Compute add(a, b) and sum of squares in one pass
Tcompute sum_squared = 0;
for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
Tcompute sum_val = Tcompute(a_ptr[i]) + Tcompute(b_ptr[i]);
residual_out_ptr[i] = Tdata(sum_val); // Store add result
sum_squared += sum_val * sum_val;
}
// Block-reduce sum of squares
using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
__shared__ typename BlockReduce::TempStorage temp_storage;
sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
// Thread 0 computes rms = 1 / sqrt(sum_squared / dim + epsilon) and broadcasts it via shared memory
__shared__ Tcompute rms;
if (threadIdx.x == 0) {
rms = Tcompute(rsqrtf(sum_squared / Tcompute(dim) + epsilon));
}
__syncthreads();
// Apply normalization: y = (a + b) * w * rms
// Reuse stored values from residual_out
for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
Tcompute sum_val = Tcompute(residual_out_ptr[i]); // Reuse stored value
y_ptr[i] = Tdata(sum_val * Tcompute(w_ptr[i]) * rms);
}
}
#endif
#ifndef __ADD_RMS_NORM_INFO_H__
#define __ADD_RMS_NORM_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::add_rms_norm {
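// Validated metadata for the fused Add + RMSNorm operator. Supported layouts
// are 2-D (batch, dim) and 3-D (batch, nhead, dim); the last dimension of every
// tensor, including the 1-D weight, must be contiguous (stride 1).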
class AddRMSNormInfo {
AddRMSNormInfo() = default;
public:
infiniDtype_t wtype;
infiniDtype_t atype;
float epsilon;
std::vector<size_t> shape;
std::vector<ptrdiff_t> y_strides;
std::vector<ptrdiff_t> a_strides;
std::vector<ptrdiff_t> b_strides;
std::vector<ptrdiff_t> residual_out_strides;
bool has_residual_out;
size_t ndim() const { return shape.size(); }
size_t dim() const { return shape[ndim() - 1]; }
static utils::Result<AddRMSNormInfo> create(
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon,
infiniopTensorDescriptor_t residual_out_desc) {
auto atype = y_desc->dtype();
auto wtype = weight_desc->dtype();
// Check that all input tensors have the same dtype
if (a_desc->dtype() != atype || b_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (atype == INFINI_DTYPE_F16 || atype == INFINI_DTYPE_BF16) {
// For half-precision activations (FP16/BF16), weights may be FP16, BF16, or FP32
if (wtype != atype && wtype != INFINI_DTYPE_F32 && wtype != INFINI_DTYPE_BF16 && wtype != INFINI_DTYPE_F16) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
// For FP32/FP64, activations and weights must be of the same type
if (atype != wtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
const size_t y_ndim = y_desc->ndim();
const size_t a_ndim = a_desc->ndim();
const size_t b_ndim = b_desc->ndim();
const size_t w_ndim = weight_desc->ndim();
if (y_ndim != a_ndim || y_ndim != b_ndim || w_ndim != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
size_t batch = 1;
size_t nhead = 1;
size_t dim = 0;
if (y_ndim == 2) {
batch = y_desc->dim(0);
dim = y_desc->dim(1);
if (a_desc->dim(0) != batch || a_desc->dim(1) != dim || b_desc->dim(0) != batch || b_desc->dim(1) != dim || weight_desc->dim(0) != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
} else if (y_ndim == 3) {
batch = y_desc->dim(0);
nhead = y_desc->dim(1);
dim = y_desc->dim(2);
if (a_desc->dim(0) != batch || a_desc->dim(1) != nhead || a_desc->dim(2) != dim || b_desc->dim(0) != batch || b_desc->dim(1) != nhead || b_desc->dim(2) != dim || weight_desc->dim(0) != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Check contiguity of the last dimension
if (y_desc->stride(y_ndim - 1) != 1 || a_desc->stride(a_ndim - 1) != 1 || b_desc->stride(b_ndim - 1) != 1 || weight_desc->stride(w_ndim - 1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
// residual_out_desc is required (always needed for fused operator)
if (residual_out_desc == nullptr) {
return INFINI_STATUS_BAD_PARAM;
}
const size_t residual_out_ndim = residual_out_desc->ndim();
if (residual_out_ndim != y_ndim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (residual_out_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Check shape matches
for (size_t i = 0; i < y_ndim; i++) {
if (residual_out_desc->dim(i) != y_desc->dim(i)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
if (residual_out_desc->stride(residual_out_ndim - 1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
AddRMSNormInfo info;
info.wtype = wtype;
info.atype = atype;
info.epsilon = epsilon;
info.shape = y_desc->shape();
info.y_strides = y_desc->strides();
info.a_strides = a_desc->strides();
info.b_strides = b_desc->strides();
info.has_residual_out = true; // residual_out is required by the fused operator
info.residual_out_strides = residual_out_desc->strides();
return utils::Result<AddRMSNormInfo>(info);
}
};
} // namespace op::add_rms_norm
#endif // __ADD_RMS_NORM_INFO_H__
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "add_rms_norm_nvidia.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL add_rmsnormKernel(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
size_t dim,
float epsilon) {
add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
y, residual_out,
stride_y_batch, stride_y_nhead,
stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
w, nhead, dim, epsilon);
}
namespace op::add_rms_norm::nvidia {
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon,
infiniopTensorDescriptor_t residual_out_desc) {
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
CHECK_RESULT(result);
auto info = result.take();
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
uint32_t batch_size, size_t nhead, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
const void *w, infiniDtype_t wtype,
float epsilon,
cudaStream_t cuda_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, cuda_stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const Tdata *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const Tdata *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
dim, \
epsilon)
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __nv_bfloat16, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__nv_bfloat16, __nv_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(__nv_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__nv_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
#undef LAUNCH_KERNEL
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *a, const void *b, const void *weight,
void *residual_out, void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
auto stride_a_batch = _info.a_strides[0];
auto stride_a_nhead = _info.a_strides[1];
auto stride_b_batch = _info.b_strides[0];
auto stride_b_nhead = _info.b_strides[1];
auto stride_y_batch = _info.y_strides[0];
auto stride_y_nhead = _info.y_strides[1];
auto stride_residual_out_batch = _info.residual_out_strides[0];
auto stride_residual_out_nhead = _info.residual_out_strides[1];
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
// launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, cuda_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, cuda_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_4096>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, cuda_stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::add_rms_norm::nvidia
#ifndef __ADD_RMS_NORM_NVIDIA_CUDA_H__
#define __ADD_RMS_NORM_NVIDIA_CUDA_H__
#include "../add_rms_norm.h"
DESCRIPTOR(nvidia)
#endif
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/add_rms_norm.h"
#ifdef ENABLE_CPU_API
#include "cpu/add_rms_norm_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#include "nvidia/add_rms_norm_nvidia.cuh"
#endif
#ifdef ENABLE_ASCEND_API
// TODO: Add Ascend implementation
// #include "ascend/add_rms_norm_aclnn.h"
#endif
#ifdef ENABLE_CAMBRICON_API
// TODO: Add Cambricon implementation
// #include "bang/add_rms_norm_bang.h"
#endif
#ifdef ENABLE_METAX_API
// TODO: Add Metax implementation
// #include "metax/add_rms_norm_metax.cuh"
#endif
#ifdef ENABLE_MOORE_API
// TODO: Add Moore implementation
// #include "moore/add_rms_norm_moore.h"
#endif
#ifdef ENABLE_KUNLUN_API
// TODO: Add Kunlun implementation
// #include "kunlun/add_rms_norm_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopHandle_t handle,
infiniopAddRMSNormDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon,
infiniopTensorDescriptor_t residual_out_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::add_rms_norm::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor **>(desc_ptr), \
y_desc, \
a_desc, \
b_desc, \
weight_desc, \
epsilon, \
residual_out_desc)
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
// CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
// GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopAddRMSNorm(
infiniopAddRMSNormDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *a,
const void *b,
const void *weight,
void *residual_out,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, y, a, b, weight, residual_out, stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
// CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
}
__C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc) {
if (desc == nullptr) {
return INFINI_STATUS_SUCCESS;
}
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
DESTROY(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
DESTROY(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
// DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DESTROY
}
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (y_shape, a_shape, b_shape, w_shape, y_strides, a_strides, b_strides)
_TEST_CASES_DATA = [
# Basic cases
((1, 4), (1, 4), (1, 4), (4,), None, None, None),
((2, 4), (2, 4), (2, 4), (4,), None, None, None),
((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), None, None, None),
# Strided cases
((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1), (12, 8, 1)),
# Large tensors
((16, 2048), (16, 2048), (16, 2048), (2048,), None, None, None),
((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)),
]
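# Note: the strided cases use batch/head strides larger than the dense layout
# (i.e. padded rows); the last dimension always keeps stride 1, matching the
# operator's contiguity requirement.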
# Tolerance configuration
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 2e-3, "rtol": 2e-3},
infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
infinicore.float32: {"atol": 1e-5, "rtol": 1e-4},
}
# Data types for individual tensors
_INPUT_DTYPES = [infinicore.float16, infinicore.bfloat16]
_WEIGHT_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# EPSILON constant for AddRMSNorm
_EPSILON = 1e-5
def parse_test_cases():
"""
Parse AddRMSNorm test case data and return list of TestCase objects.
Format: (y_shape, a_shape, b_shape, w_shape, y_strides, a_strides, b_strides)
"""
test_cases = []
for data in _TEST_CASES_DATA:
y_shape = data[0] # Output shape
a_shape = data[1] # First input shape
b_shape = data[2] # Second input shape
w_shape = data[3] # Weight shape (1D)
y_strides = data[4] if len(data) > 4 else None
a_strides = data[5] if len(data) > 5 else None
b_strides = data[6] if len(data) > 6 else None
# Check if tensors support in-place operations
a_supports_inplace = not is_broadcast(a_strides)
b_supports_inplace = not is_broadcast(b_strides)
y_supports_inplace = not is_broadcast(y_strides)
# Generate test cases for all dtype combinations
for input_dtype in _INPUT_DTYPES:
for weight_dtype in _WEIGHT_DTYPES:
# Use input dtype tolerance for output
tolerance = _TOLERANCE_MAP.get(
input_dtype, {"atol": 1e-5, "rtol": 1e-4}
)
# Create typed tensor specs
a_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
b_spec = TensorSpec.from_tensor(b_shape, b_strides, input_dtype)
w_spec = TensorSpec.from_tensor(
w_shape, None, weight_dtype
) # Weight is always contiguous
y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype)
# Test Case 1: Out-of-place (return value) - returns (normalized_result, add_result)
residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
test_cases.append(
TestCase(
inputs=[a_spec, b_spec, w_spec],
kwargs={"epsilon": _EPSILON},
output_specs=[y_spec, residual_out_spec], # Two outputs
comparison_target=None,
tolerance=tolerance,
output_count=2, # Two outputs: normalized_result and add_result
description=f"AddRMSNorm - OUT_OF_PLACE",
)
)
# Test Case 2: In-place with explicit output tensors (add_rms_norm_(y, residual_out, a, b, w))
if y_supports_inplace:
residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
test_cases.append(
TestCase(
inputs=[a_spec, b_spec, w_spec],
kwargs={"epsilon": _EPSILON, "out": (y_spec, residual_out_spec)},
output_specs=[y_spec, residual_out_spec], # Two outputs
comparison_target="out",
tolerance=tolerance,
output_count=2,
description=f"AddRMSNorm - INPLACE(out)",
)
)
return test_cases
class OpTest(BaseOperatorTest):
"""AddRMSNorm operator test with simplified implementation"""
def __init__(self):
super().__init__("AddRMSNorm")
def get_test_cases(self):
return parse_test_cases()
def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
"""PyTorch AddRMSNorm implementation - returns (normalized_result, add_result)"""
input_dtype = a.dtype
# Compute add(a, b)
sum_tensor = a.to(torch.float32) + b.to(torch.float32)
weight_fp32 = weight.to(torch.float32)
# Calculate RMSNorm: (a + b) * weight / sqrt(mean((a+b)^2) + epsilon)
variance = sum_tensor.pow(2).mean(-1, keepdim=True)
normalized_result = sum_tensor * torch.rsqrt(variance + epsilon) * weight_fp32
# Convert back to original dtype
normalized_result = normalized_result.to(input_dtype)
add_result = sum_tensor.to(input_dtype)
if out is not None:
# For in-place operations, we need to handle the output tuple
if isinstance(out, (tuple, list)) and len(out) == 2:
out[0].copy_(normalized_result)
out[1].copy_(add_result)
return tuple(out)
else:
# Single output - just return normalized result for backward compatibility
out.copy_(normalized_result)
return out
return (normalized_result, add_result)
def infinicore_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
"""InfiniCore AddRMSNorm implementation - returns (normalized_result, add_result)"""
return infinicore.add_rms_norm(a, b, weight, epsilon, out=out)
def main():
"""Main entry point"""
runner = GenericTestRunner(OpTest)
runner.run_and_exit()
if __name__ == "__main__":
main()
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
TestWorkspace,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
)
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# y_shape, a_shape, b_shape, w_shape, y_stride, a_stride, b_stride
((1, 4), (1, 4), (1, 4), (4,), None, None, None),
((2, 4), (2, 4), (2, 4), (4,), None, None, None),
((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), None, None, None),
((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1), (12, 8, 1)),
((16, 2048), (16, 2048), (16, 2048), (2048,), None, None, None),
((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((15, 8192), (15, 8192), (15, 8192), (8192,), None, None, None),
]
# w (weight) types
# Note: 'None' means the same as input dtype
_WEIGHT_DTYPES = [None, InfiniDtype.F32, InfiniDtype.F16, InfiniDtype.BF16]
# a, b types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16]
# Form the test cases by appending each element of _WEIGHT_DTYPES to each tuple in _TEST_CASES_
_TEST_CASES = [
test_case + (w_dtype,) for test_case in _TEST_CASES_ for w_dtype in _WEIGHT_DTYPES
]
# Tolerance map for different data types
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 2e-3, "rtol": 2e-3},
InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def add_rms_norm(ans, a, b, w, eps):
input_dtype = a.dtype
# Compute add(a, b)
sum_tensor = a.to(torch.float32) + b.to(torch.float32)
# Compute RMS normalization
scale = sum_tensor.pow(2).mean(-1, keepdim=True).add_(eps).rsqrt_()
ans.set_((sum_tensor.mul_(scale).mul_(w.to(torch.float32))).to(input_dtype))
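# Note: mul_ mutates sum_tensor in place, so the (a + b) reference used to
# verify residual_out is recomputed separately in test() below.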
def test(
handle,
device,
y_shape,
a_shape,
b_shape,
w_shape,
y_stride,
a_stride,
b_stride,
w_dtype=InfiniDtype.F32,
dtype=InfiniDtype.F16,
sync=None,
):
w_dtype = w_dtype if w_dtype else dtype
print(
f"Testing AddRMSNorm on {InfiniDeviceNames[device]} with y_shape:{y_shape} a_shape:{a_shape} b_shape:{b_shape} w_shape:{w_shape}"
f" y_stride:{y_stride} a_stride:{a_stride} b_stride:{b_stride} w_dtype:{InfiniDtypeNames[w_dtype]} dtype:{InfiniDtypeNames[dtype]}"
)
y = TestTensor(y_shape, y_stride, dtype, device, mode="ones")
residual_out = TestTensor(a_shape, a_stride, dtype, device, mode="ones")
a = TestTensor(a_shape, a_stride, dtype, device, scale=0.01)
b = TestTensor(b_shape, b_stride, dtype, device, scale=0.01)
w = TestTensor(w_shape, None, w_dtype, device)
eps = 1e-6
add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps)
if sync is not None:
sync()
descriptor = infiniopOperatorDescriptor_t()
check_error(
LIBINFINIOP.infiniopCreateAddRMSNormDescriptor(
handle,
ctypes.byref(descriptor),
y.descriptor,
a.descriptor,
b.descriptor,
w.descriptor,
eps,
residual_out.descriptor,
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
for tensor in [a, b, y, w, residual_out]:
tensor.destroy_desc()
workspace_size = c_uint64(0)
check_error(
LIBINFINIOP.infiniopGetAddRMSNormWorkspaceSize(
descriptor, ctypes.byref(workspace_size)
)
)
workspace = TestWorkspace(workspace_size.value, y.device)
def lib_add_rms_norm():
check_error(
LIBINFINIOP.infiniopAddRMSNorm(
descriptor,
workspace.data(),
workspace_size.value,
y.data(),
a.data(),
b.data(),
w.data(),
residual_out.data(),
None,
)
)
lib_add_rms_norm()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
# Verify normalized result (y)
if DEBUG:
debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
# Verify add result (residual_out) - should be a + b
expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to(torch.float32)
expected_residual = expected_residual.to(a.torch_tensor().dtype)
if DEBUG:
debug(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)
assert torch.allclose(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps), device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_add_rms_norm(), device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(LIBINFINIOP.infiniopDestroyAddRMSNormDescriptor(descriptor))
if __name__ == "__main__":
args = get_args()
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
# Execute tests
for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")