Commit 2a432b34 authored by zhuyue, committed by gongchensu
Browse files

Unify add_rms_norm to always return (normalized_result, add_result) pair.

parent 7d60e5b8
...@@ -2,15 +2,19 @@ ...@@ -2,15 +2,19 @@
#include "../device.hpp" #include "../device.hpp"
#include "common/op.hpp" #include "common/op.hpp"
#include <utility>
namespace infinicore::op { namespace infinicore::op {
class AddRMSNorm { class AddRMSNorm {
public: public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, float); using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float);
static void execute(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
static common::OpDispatcher<schema> &dispatcher(); static common::OpDispatcher<schema> &dispatcher();
}; };
Tensor add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); // Fused Add and RMS Normalization
void add_rms_norm_(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); // Returns: (normalized_result, add_result)
// The add_result can be used as residual for subsequent layers
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
} // namespace infinicore::op } // namespace infinicore::op
...@@ -12,7 +12,8 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor( ...@@ -12,7 +12,8 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc, infiniopTensorDescriptor_t weight_desc,
float epsilon); float epsilon,
infiniopTensorDescriptor_t residual_out_desc);
__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size); __C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size);
...@@ -23,6 +24,7 @@ __C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t de ...@@ -23,6 +24,7 @@ __C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t de
const void *a, const void *a,
const void *b, const void *b,
const void *weight, const void *weight,
void *residual_out,
void *stream); void *stream);
__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc); __C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc);
......
...@@ -40,7 +40,7 @@ from infinicore.dtype import ( ...@@ -40,7 +40,7 @@ from infinicore.dtype import (
uint8, uint8,
) )
from infinicore.ops.add import add from infinicore.ops.add import add
from infinicore.ops.add_rms_norm import add_rms_norm from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_
from infinicore.ops.attention import attention from infinicore.ops.attention import attention
from infinicore.ops.matmul import matmul from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul from infinicore.ops.mul import mul
...@@ -104,6 +104,7 @@ __all__ = [ ...@@ -104,6 +104,7 @@ __all__ = [
# Operations. # Operations.
"add", "add",
"add_rms_norm", "add_rms_norm",
"add_rms_norm_",
"attention", "attention",
"matmul", "matmul",
"mul", "mul",
......
...@@ -3,9 +3,29 @@ from infinicore.tensor import Tensor ...@@ -3,9 +3,29 @@ from infinicore.tensor import Tensor
def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None): def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
"""
Fused Add and RMS Normalization.
Args:
a: First input tensor
b: Second input tensor
weight: Scale weights
epsilon: Small constant for numerical stability, default is 1e-5
out: Optional output tuple (y, residual_out) for in-place operation
Returns:
Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b)
The add_result can be used as residual for subsequent layers.
"""
if out is None: if out is None:
return Tensor(_infinicore.add_rms_norm(a._underlying, b._underlying, weight._underlying, epsilon)) result = _infinicore.add_rms_norm(a._underlying, b._underlying, weight._underlying, epsilon)
return (Tensor(result[0]), Tensor(result[1]))
y, residual_out = out
_infinicore.add_rms_norm_(y._underlying, residual_out._underlying, a._underlying, b._underlying, weight._underlying, epsilon)
return (y, residual_out)
_infinicore.add_rms_norm_(out._underlying, a._underlying, b._underlying, weight._underlying, epsilon)
return out def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5):
"""In-place Fused Add and RMS Normalization."""
_infinicore.add_rms_norm_(y._underlying, residual_out._underlying, a._underlying, b._underlying, weight._underlying, epsilon)
...@@ -9,20 +9,21 @@ common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() { ...@@ -9,20 +9,21 @@ common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() {
return dispatcher_; return dispatcher_;
}; };
void AddRMSNorm::execute(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) { void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, a, b, weight); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight);
infinicore::context::setDevice(y->device()); infinicore::context::setDevice(y->device());
dispatcher().lookup(y->device().getType())(y, a, b, weight, epsilon); dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon);
} }
Tensor add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) { std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) {
auto y = Tensor::empty(a->shape(), a->dtype(), a->device()); auto y = Tensor::empty(a->shape(), a->dtype(), a->device());
add_rms_norm_(y, a, b, weight, epsilon); auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device());
return y; add_rms_norm_(y, residual_out, a, b, weight, epsilon);
return std::make_pair(y, residual_out);
} }
void add_rms_norm_(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) { void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
AddRMSNorm::execute(y, a, b, weight, epsilon); AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon);
} }
} // namespace infinicore::op } // namespace infinicore::op
...@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches( ...@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches(
} }
}); });
void calculate(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) { void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
size_t seed = hash_combine(y, a, b, weight, epsilon); size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon);
auto device = context::getDevice(); auto device = context::getDevice();
auto &cache = caches.getCache(device); auto &cache = caches.getCache(device);
...@@ -27,7 +27,7 @@ void calculate(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) { ...@@ -27,7 +27,7 @@ void calculate(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) {
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor(
context::getInfiniopHandle(device), &desc, context::getInfiniopHandle(device), &desc,
y->desc(), a->desc(), b->desc(), weight->desc(), epsilon)); y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
desc = *desc_opt; desc = *desc_opt;
...@@ -39,7 +39,7 @@ void calculate(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) { ...@@ -39,7 +39,7 @@ void calculate(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) {
INFINICORE_CHECK_ERROR(infiniopAddRMSNorm( INFINICORE_CHECK_ERROR(infiniopAddRMSNorm(
desc, workspace->data(), workspace_size, desc, workspace->data(), workspace_size,
y->data(), a->data(), b->data(), weight->data(), context::getStream())); y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream()));
} }
static bool registered = []() { static bool registered = []() {
......
...@@ -24,12 +24,14 @@ Args: ...@@ -24,12 +24,14 @@ Args:
epsilon: Small constant for numerical stability, default is 1e-5 epsilon: Small constant for numerical stability, default is 1e-5
Returns: Returns:
Normalized tensor: RMSNorm(a + b) * weight Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b)
The add_result can be used as residual for subsequent layers.
)doc"); )doc");
m.def("add_rms_norm_", m.def("add_rms_norm_",
&op::add_rms_norm_, &op::add_rms_norm_,
py::arg("y"), py::arg("y"),
py::arg("residual_out"),
py::arg("a"), py::arg("a"),
py::arg("b"), py::arg("b"),
py::arg("weight"), py::arg("weight"),
...@@ -37,7 +39,8 @@ Returns: ...@@ -37,7 +39,8 @@ Returns:
R"doc(In-place Fused Add and RMS Normalization. R"doc(In-place Fused Add and RMS Normalization.
Args: Args:
y: Output tensor y: Output tensor for normalized result
residual_out: Output tensor for add result (a + b) before normalization
a: First input tensor a: First input tensor
b: Second input tensor b: Second input tensor
weight: Scale weights weight: Scale weights
......
...@@ -36,7 +36,8 @@ ...@@ -36,7 +36,8 @@
infiniopTensorDescriptor_t a_desc, \ infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \ infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t weight_desc, \ infiniopTensorDescriptor_t weight_desc, \
float epsilon); \ float epsilon, \
infiniopTensorDescriptor_t residual_out_desc); \
\ \
infiniStatus_t calculate( \ infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \ void *workspace, size_t workspace_size, \
...@@ -44,6 +45,7 @@ ...@@ -44,6 +45,7 @@
const void *a, \ const void *a, \
const void *b, \ const void *b, \
const void *weight, \ const void *weight, \
void *residual_out, \
void *stream) const; \ void *stream) const; \
}; \ }; \
} }
......
...@@ -13,15 +13,16 @@ infiniStatus_t Descriptor::create( ...@@ -13,15 +13,16 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc, infiniopTensorDescriptor_t weight_desc,
float epsilon) { float epsilon,
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon); infiniopTensorDescriptor_t residual_out_desc) {
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
CHECK_RESULT(result); CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id); *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
template <typename T> template <typename T>
infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const T *w) { infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const T *w, T *residual_out) {
const size_t batch_size = info->shape[0]; const size_t batch_size = info->shape[0];
const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1; const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
const size_t dim = info->dim(); const size_t dim = info->dim();
...@@ -35,12 +36,16 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T ...@@ -35,12 +36,16 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1]; const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1]; const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1]; T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
T *residual_out_ptr = info->has_residual_out ?
(residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1]) : nullptr;
// First, compute add(a, b) and store sum values // Compute add(a, b) once and store it
// We'll compute RMS norm directly on the sum
T sum_squared = (T)0; T sum_squared = (T)0;
for (size_t k = 0; k < dim; k++) { for (size_t k = 0; k < dim; k++) {
T sum_val = a_ptr[k] + b_ptr[k]; T sum_val = a_ptr[k] + b_ptr[k];
if (residual_out_ptr != nullptr) {
residual_out_ptr[k] = sum_val; // Store add result
}
sum_squared += sum_val * sum_val; sum_squared += sum_val * sum_val;
} }
...@@ -49,18 +54,26 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T ...@@ -49,18 +54,26 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
T rms = (T)1 / std::sqrt(sum_squared / (T)(dim) + (T)(info->epsilon)); T rms = (T)1 / std::sqrt(sum_squared / (T)(dim) + (T)(info->epsilon));
// Apply normalization: y = (a + b) * w * rms // Apply normalization: y = (a + b) * w * rms
// Recompute sum to avoid storing temporary array // Reuse the stored sum values if residual_out was computed, otherwise recompute
if (residual_out_ptr != nullptr) {
// Reuse stored values
for (size_t k = 0; k < dim; k++) {
y_ptr[k] = residual_out_ptr[k] * w[k] * rms;
}
} else {
// Recompute sum
for (size_t k = 0; k < dim; k++) { for (size_t k = 0; k < dim; k++) {
T sum_val = a_ptr[k] + b_ptr[k]; T sum_val = a_ptr[k] + b_ptr[k];
y_ptr[k] = sum_val * w[k] * rms; y_ptr[k] = sum_val * w[k] * rms;
} }
} }
}
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
template <typename T, typename Tw> template <typename T, typename Tw>
infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const Tw *w) { infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const Tw *w, T *residual_out) {
static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value, static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value,
"T must be fp16_t or bf16_t"); "T must be fp16_t or bf16_t");
...@@ -77,11 +90,16 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const ...@@ -77,11 +90,16 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1]; const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1]; const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1]; T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
T *residual_out_ptr = info->has_residual_out ?
(residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1]) : nullptr;
// Compute sum of squares for RMS normalization // Compute sum of squares for RMS normalization and store add result
float sum_squared = 0.0f; float sum_squared = 0.0f;
for (size_t k = 0; k < dim; k++) { for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]); float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
if (residual_out_ptr != nullptr) {
residual_out_ptr[k] = utils::cast<T>(sum_val); // Store add result
}
sum_squared += sum_val * sum_val; sum_squared += sum_val * sum_val;
} }
...@@ -89,6 +107,23 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const ...@@ -89,6 +107,23 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
float rms = 1.f / std::sqrt(sum_squared / (float)(dim) + info->epsilon); float rms = 1.f / std::sqrt(sum_squared / (float)(dim) + info->epsilon);
// Apply normalization: y = (a + b) * w * rms // Apply normalization: y = (a + b) * w * rms
// Reuse stored values if residual_out was computed, otherwise recompute
if (residual_out_ptr != nullptr) {
// Reuse stored values
for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(residual_out_ptr[k]);
float val;
if constexpr (std::is_same<Tw, float>::value) {
val = sum_val * w[k] * rms;
} else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
val = sum_val * utils::cast<float>(w[k]) * rms;
} else {
std::abort();
}
y_ptr[k] = utils::cast<T>(val);
}
} else {
// Recompute sum
for (size_t k = 0; k < dim; k++) { for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]); float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
float val; float val;
...@@ -102,6 +137,7 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const ...@@ -102,6 +137,7 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
y_ptr[k] = utils::cast<T>(val); y_ptr[k] = utils::cast<T>(val);
} }
} }
}
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
...@@ -109,31 +145,31 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const ...@@ -109,31 +145,31 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size, void *workspace, size_t workspace_size,
void *y, const void *a, const void *b, const void *weight, void *y, const void *a, const void *b, const void *weight,
void *stream) const { void *residual_out, void *stream) const {
if (_info.atype == INFINI_DTYPE_F16) { if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) { if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight)); CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight, (fp16_t *)residual_out));
} else if (_info.wtype == INFINI_DTYPE_F32) { } else if (_info.wtype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight)); CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight, (fp16_t *)residual_out));
} else if (_info.wtype == INFINI_DTYPE_BF16) { } else if (_info.wtype == INFINI_DTYPE_BF16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight)); CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight, (fp16_t *)residual_out));
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
} else if (_info.atype == INFINI_DTYPE_BF16) { } else if (_info.atype == INFINI_DTYPE_BF16) {
if (_info.wtype == INFINI_DTYPE_BF16) { if (_info.wtype == INFINI_DTYPE_BF16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight)); CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight, (bf16_t *)residual_out));
} else if (_info.wtype == INFINI_DTYPE_F32) { } else if (_info.wtype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight)); CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight, (bf16_t *)residual_out));
} else if (_info.wtype == INFINI_DTYPE_F16) { } else if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight)); CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight, (bf16_t *)residual_out));
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
} else if (_info.atype == INFINI_DTYPE_F32) { } else if (_info.atype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (const float *)a, (const float *)b, (const float *)weight)); CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (const float *)a, (const float *)b, (const float *)weight, (float *)residual_out));
} else if (_info.atype == INFINI_DTYPE_F64) { } else if (_info.atype == INFINI_DTYPE_F64) {
CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (const double *)a, (const double *)b, (const double *)weight)); CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (const double *)a, (const double *)b, (const double *)weight, (double *)residual_out));
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
......
...@@ -18,6 +18,8 @@ public: ...@@ -18,6 +18,8 @@ public:
std::vector<ptrdiff_t> y_strides; std::vector<ptrdiff_t> y_strides;
std::vector<ptrdiff_t> a_strides; std::vector<ptrdiff_t> a_strides;
std::vector<ptrdiff_t> b_strides; std::vector<ptrdiff_t> b_strides;
std::vector<ptrdiff_t> residual_out_strides;
bool has_residual_out;
size_t ndim() const { return shape.size(); } size_t ndim() const { return shape.size(); }
size_t dim() const { return shape[ndim() - 1]; } size_t dim() const { return shape[ndim() - 1]; }
...@@ -27,7 +29,8 @@ public: ...@@ -27,7 +29,8 @@ public:
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc, infiniopTensorDescriptor_t weight_desc,
float epsilon) { float epsilon,
infiniopTensorDescriptor_t residual_out_desc) {
auto atype = y_desc->dtype(); auto atype = y_desc->dtype();
auto wtype = weight_desc->dtype(); auto wtype = weight_desc->dtype();
...@@ -95,6 +98,27 @@ public: ...@@ -95,6 +98,27 @@ public:
return INFINI_STATUS_BAD_TENSOR_STRIDES; return INFINI_STATUS_BAD_TENSOR_STRIDES;
} }
// Check residual_out_desc if provided
bool has_residual_out = (residual_out_desc != nullptr);
if (has_residual_out) {
const size_t residual_out_ndim = residual_out_desc->ndim();
if (residual_out_ndim != y_ndim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (residual_out_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Check shape matches
for (size_t i = 0; i < y_ndim; i++) {
if (residual_out_desc->dim(i) != y_desc->dim(i)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
if (residual_out_desc->stride(residual_out_ndim - 1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
}
AddRMSNormInfo info; AddRMSNormInfo info;
info.wtype = wtype; info.wtype = wtype;
info.atype = atype; info.atype = atype;
...@@ -103,6 +127,10 @@ public: ...@@ -103,6 +127,10 @@ public:
info.y_strides = y_desc->strides(); info.y_strides = y_desc->strides();
info.a_strides = a_desc->strides(); info.a_strides = a_desc->strides();
info.b_strides = b_desc->strides(); info.b_strides = b_desc->strides();
info.has_residual_out = has_residual_out;
if (has_residual_out) {
info.residual_out_strides = residual_out_desc->strides();
}
return utils::Result<AddRMSNormInfo>(info); return utils::Result<AddRMSNormInfo>(info);
} }
}; };
......
...@@ -37,7 +37,8 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor( ...@@ -37,7 +37,8 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc, infiniopTensorDescriptor_t weight_desc,
float epsilon) { float epsilon,
infiniopTensorDescriptor_t residual_out_desc) {
#define CREATE(CASE, NAMESPACE) \ #define CREATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
...@@ -48,7 +49,8 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor( ...@@ -48,7 +49,8 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
a_desc, \ a_desc, \
b_desc, \ b_desc, \
weight_desc, \ weight_desc, \
epsilon) epsilon, \
residual_out_desc)
switch (handle->device) { switch (handle->device) {
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
...@@ -118,12 +120,13 @@ __C infiniStatus_t infiniopAddRMSNorm( ...@@ -118,12 +120,13 @@ __C infiniStatus_t infiniopAddRMSNorm(
const void *a, const void *a,
const void *b, const void *b,
const void *weight, const void *weight,
void *residual_out,
void *stream) { void *stream) {
#define CALCULATE(CASE, NAMESPACE) \ #define CALCULATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc) \ return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, y, a, b, weight, stream) ->calculate(workspace, workspace_size, y, a, b, weight, residual_out, stream)
switch (desc->device_type) { switch (desc->device_type) {
......
...@@ -86,63 +86,35 @@ def parse_test_cases(): ...@@ -86,63 +86,35 @@ def parse_test_cases():
) # Weight is always contiguous ) # Weight is always contiguous
y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype) y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype)
# Test Case 1: Out-of-place (return value) # Test Case 1: Out-of-place (return value) - returns (normalized_result, add_result)
residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
test_cases.append( test_cases.append(
TestCase( TestCase(
inputs=[a_spec, b_spec, w_spec], inputs=[a_spec, b_spec, w_spec],
kwargs={"epsilon": _EPSILON}, kwargs={"epsilon": _EPSILON},
output_spec=None, output_specs=[y_spec, residual_out_spec], # Two outputs
comparison_target=None, comparison_target=None,
tolerance=tolerance, tolerance=tolerance,
output_count=2, # Two outputs: normalized_result and add_result
description=f"AddRMSNorm - OUT_OF_PLACE", description=f"AddRMSNorm - OUT_OF_PLACE",
) )
) )
# Test Case 2: In-place with explicit output tensor (add_rms_norm(a, b, w, out=y)) # Test Case 2: In-place with explicit output tensors (add_rms_norm_(y, residual_out, a, b, w))
if y_supports_inplace: if y_supports_inplace:
residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
test_cases.append( test_cases.append(
TestCase( TestCase(
inputs=[a_spec, b_spec, w_spec], inputs=[a_spec, b_spec, w_spec],
kwargs={"epsilon": _EPSILON}, kwargs={"epsilon": _EPSILON, "out": (y_spec, residual_out_spec)},
output_spec=y_spec, # Specify the output tensor spec output_specs=[y_spec, residual_out_spec], # Two outputs
comparison_target="out", comparison_target="out",
tolerance=tolerance, tolerance=tolerance,
output_count=2,
description=f"AddRMSNorm - INPLACE(out)", description=f"AddRMSNorm - INPLACE(out)",
) )
) )
# Test Case 3: In-place on first input (add_rms_norm(a, b, w, out=a))
if a_supports_inplace:
test_cases.append(
TestCase(
inputs=[a_spec, b_spec, w_spec],
kwargs={
"out": 0,
"epsilon": _EPSILON,
}, # Use index 0 for first input
output_spec=None,
comparison_target=0, # Compare first input
tolerance=tolerance,
description=f"AddRMSNorm - INPLACE(a)",
)
)
# Test Case 4: In-place on second input (add_rms_norm(a, b, w, out=b))
if b_supports_inplace:
test_cases.append(
TestCase(
inputs=[a_spec, b_spec, w_spec],
kwargs={
"out": 1,
"epsilon": _EPSILON,
}, # Use index 1 for second input
output_spec=None,
comparison_target=1, # Compare second input
tolerance=tolerance,
description=f"AddRMSNorm - INPLACE(b)",
)
)
return test_cases return test_cases
...@@ -156,7 +128,7 @@ class OpTest(BaseOperatorTest): ...@@ -156,7 +128,7 @@ class OpTest(BaseOperatorTest):
return parse_test_cases() return parse_test_cases()
def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs): def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
"""PyTorch AddRMSNorm implementation""" """PyTorch AddRMSNorm implementation - returns (normalized_result, add_result)"""
input_dtype = a.dtype input_dtype = a.dtype
# Compute add(a, b) # Compute add(a, b)
...@@ -165,18 +137,27 @@ class OpTest(BaseOperatorTest): ...@@ -165,18 +137,27 @@ class OpTest(BaseOperatorTest):
# Calculate RMSNorm: (a + b) * weight / sqrt(mean((a+b)^2) + epsilon) # Calculate RMSNorm: (a + b) * weight / sqrt(mean((a+b)^2) + epsilon)
variance = sum_tensor.pow(2).mean(-1, keepdim=True) variance = sum_tensor.pow(2).mean(-1, keepdim=True)
result = sum_tensor * torch.rsqrt(variance + epsilon) * weight_fp32 normalized_result = sum_tensor * torch.rsqrt(variance + epsilon) * weight_fp32
# Convert back to original dtype # Convert back to original dtype
result = result.to(input_dtype) normalized_result = normalized_result.to(input_dtype)
add_result = sum_tensor.to(input_dtype)
if out is not None: if out is not None:
out.copy_(result) # For in-place operations, we need to handle the output tuple
if isinstance(out, (tuple, list)) and len(out) == 2:
out[0].copy_(normalized_result)
out[1].copy_(add_result)
return tuple(out)
else:
# Single output - just return normalized result for backward compatibility
out.copy_(normalized_result)
return out return out
return result
return (normalized_result, add_result)
def infinicore_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs): def infinicore_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
"""InfiniCore AddRMSNorm implementation""" """InfiniCore AddRMSNorm implementation - returns (normalized_result, add_result)"""
return infinicore.add_rms_norm(a, b, weight, epsilon, out=out) return infinicore.add_rms_norm(a, b, weight, epsilon, out=out)
......
...@@ -91,6 +91,7 @@ def test( ...@@ -91,6 +91,7 @@ def test(
) )
y = TestTensor(y_shape, y_stride, dtype, device, mode="ones") y = TestTensor(y_shape, y_stride, dtype, device, mode="ones")
residual_out = TestTensor(a_shape, a_stride, dtype, device, mode="ones")
a = TestTensor(a_shape, a_stride, dtype, device, scale=0.01) a = TestTensor(a_shape, a_stride, dtype, device, scale=0.01)
b = TestTensor(b_shape, b_stride, dtype, device, scale=0.01) b = TestTensor(b_shape, b_stride, dtype, device, scale=0.01)
w = TestTensor(w_shape, None, w_dtype, device) w = TestTensor(w_shape, None, w_dtype, device)
...@@ -112,11 +113,12 @@ def test( ...@@ -112,11 +113,12 @@ def test(
b.descriptor, b.descriptor,
w.descriptor, w.descriptor,
eps, eps,
residual_out.descriptor,
) )
) )
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
for tensor in [a, b, y, w]: for tensor in [a, b, y, w, residual_out]:
tensor.destroy_desc() tensor.destroy_desc()
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
...@@ -137,6 +139,7 @@ def test( ...@@ -137,6 +139,7 @@ def test(
a.data(), a.data(),
b.data(), b.data(),
w.data(), w.data(),
residual_out.data(),
None, None,
) )
) )
...@@ -144,10 +147,19 @@ def test( ...@@ -144,10 +147,19 @@ def test(
lib_add_rms_norm() lib_add_rms_norm()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
# Verify normalized result (y)
if DEBUG: if DEBUG:
debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
# Verify add result (residual_out) - should be a + b
expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to(torch.float32)
expected_residual = expected_residual.to(a.torch_tensor().dtype)
if DEBUG:
debug(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)
assert torch.allclose(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
# fmt: off # fmt: off
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment