Commit 2a432b34 authored by zhuyue's avatar zhuyue Committed by gongchensu
Browse files

Unify add_rms_norm to always return (normalized_result, add_result) pair.

parent 7d60e5b8
......@@ -2,15 +2,19 @@
#include "../device.hpp"
#include "common/op.hpp"
#include <utility>
namespace infinicore::op {

// Fused Add + RMS Normalization operator with per-device dispatch.
class AddRMSNorm {
public:
    // Kernel signature: (y, residual_out, a, b, weight, epsilon).
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float);
    // Runs the fused op: writes RMSNorm(a + b) * weight into y and (a + b) into residual_out.
    static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
    // Registry of per-device-type implementations.
    static common::OpDispatcher<schema> &dispatcher();
};

// Fused Add and RMS Normalization
// Returns: (normalized_result, add_result)
// The add_result can be used as residual for subsequent layers
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
// In-place variant: caller provides both output tensors.
void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);

} // namespace infinicore::op
......@@ -12,7 +12,8 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon);
float epsilon,
infiniopTensorDescriptor_t residual_out_desc);
__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size);
......@@ -23,6 +24,7 @@ __C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t de
const void *a,
const void *b,
const void *weight,
void *residual_out,
void *stream);
__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc);
......
......@@ -40,7 +40,7 @@ from infinicore.dtype import (
uint8,
)
from infinicore.ops.add import add
from infinicore.ops.add_rms_norm import add_rms_norm
from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_
from infinicore.ops.attention import attention
from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul
......@@ -104,6 +104,7 @@ __all__ = [
# Operations.
"add",
"add_rms_norm",
"add_rms_norm_",
"attention",
"matmul",
"mul",
......
......@@ -3,9 +3,29 @@ from infinicore.tensor import Tensor
def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
    """
    Fused Add and RMS Normalization.

    Args:
        a: First input tensor
        b: Second input tensor
        weight: Scale weights
        epsilon: Small constant for numerical stability, default is 1e-5
        out: Optional output tuple (y, residual_out) for in-place operation

    Returns:
        Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b)
        The add_result can be used as residual for subsequent layers.
    """
    if out is None:
        # Out-of-place: backend returns a pair of raw tensors; wrap both.
        result = _infinicore.add_rms_norm(a._underlying, b._underlying, weight._underlying, epsilon)
        return (Tensor(result[0]), Tensor(result[1]))
    # In-place: caller supplies both output tensors as a (y, residual_out) pair.
    y, residual_out = out
    _infinicore.add_rms_norm_(y._underlying, residual_out._underlying, a._underlying, b._underlying, weight._underlying, epsilon)
    return (y, residual_out)
def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5):
    """In-place Fused Add and RMS Normalization.

    Writes RMSNorm(a + b) * weight into ``y`` and the raw sum (a + b)
    into ``residual_out`` (usable as the residual for subsequent layers).

    Args:
        y: Output tensor for the normalized result.
        residual_out: Output tensor for the add result (a + b).
        a: First input tensor.
        b: Second input tensor.
        weight: Scale weights.
        epsilon: Small constant for numerical stability (default 1e-5).
    """
    handles = (
        y._underlying,
        residual_out._underlying,
        a._underlying,
        b._underlying,
        weight._underlying,
    )
    _infinicore.add_rms_norm_(*handles, epsilon)
......@@ -9,20 +9,21 @@ common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() {
return dispatcher_;
};
// Validates device placement, binds the context to y's device, and
// dispatches to the backend kernel registered for that device type.
void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight);
    infinicore::context::setDevice(y->device());
    dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon);
}
Tensor add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) {
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) {
auto y = Tensor::empty(a->shape(), a->dtype(), a->device());
add_rms_norm_(y, a, b, weight, epsilon);
return y;
auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device());
add_rms_norm_(y, residual_out, a, b, weight, epsilon);
return std::make_pair(y, residual_out);
}
// In-place fused Add + RMS Normalization: writes the normalized result to y
// and the raw sum (a + b) to residual_out via the dispatched backend kernel.
void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
    AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon);
}
} // namespace infinicore::op
......@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches(
}
});
void calculate(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) {
size_t seed = hash_combine(y, a, b, weight, epsilon);
void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
......@@ -27,7 +27,7 @@ void calculate(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) {
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor(
context::getInfiniopHandle(device), &desc,
y->desc(), a->desc(), b->desc(), weight->desc(), epsilon));
y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
......@@ -39,7 +39,7 @@ void calculate(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) {
INFINICORE_CHECK_ERROR(infiniopAddRMSNorm(
desc, workspace->data(), workspace_size,
y->data(), a->data(), b->data(), weight->data(), context::getStream()));
y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream()));
}
static bool registered = []() {
......
......@@ -24,12 +24,14 @@ Args:
epsilon: Small constant for numerical stability, default is 1e-5
Returns:
Normalized tensor: RMSNorm(a + b) * weight
Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b)
The add_result can be used as residual for subsequent layers.
)doc");
m.def("add_rms_norm_",
&op::add_rms_norm_,
py::arg("y"),
py::arg("residual_out"),
py::arg("a"),
py::arg("b"),
py::arg("weight"),
......@@ -37,7 +39,8 @@ Returns:
R"doc(In-place Fused Add and RMS Normalization.
Args:
y: Output tensor
y: Output tensor for normalized result
residual_out: Output tensor for add result (a + b) before normalization
a: First input tensor
b: Second input tensor
weight: Scale weights
......
......@@ -36,7 +36,8 @@
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t weight_desc, \
float epsilon); \
float epsilon, \
infiniopTensorDescriptor_t residual_out_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
......@@ -44,6 +45,7 @@
const void *a, \
const void *b, \
const void *weight, \
void *residual_out, \
void *stream) const; \
}; \
}
......
......@@ -13,15 +13,16 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon) {
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon);
float epsilon,
infiniopTensorDescriptor_t residual_out_desc) {
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename T>
infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const T *w) {
infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const T *w, T *residual_out) {
const size_t batch_size = info->shape[0];
const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
const size_t dim = info->dim();
......@@ -35,12 +36,16 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
T *residual_out_ptr = info->has_residual_out ?
(residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1]) : nullptr;
// First, compute add(a, b) and store sum values
// We'll compute RMS norm directly on the sum
// Compute add(a, b) once and store it
T sum_squared = (T)0;
for (size_t k = 0; k < dim; k++) {
T sum_val = a_ptr[k] + b_ptr[k];
if (residual_out_ptr != nullptr) {
residual_out_ptr[k] = sum_val; // Store add result
}
sum_squared += sum_val * sum_val;
}
......@@ -49,10 +54,18 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
T rms = (T)1 / std::sqrt(sum_squared / (T)(dim) + (T)(info->epsilon));
// Apply normalization: y = (a + b) * w * rms
// Recompute sum to avoid storing temporary array
for (size_t k = 0; k < dim; k++) {
T sum_val = a_ptr[k] + b_ptr[k];
y_ptr[k] = sum_val * w[k] * rms;
// Reuse the stored sum values if residual_out was computed, otherwise recompute
if (residual_out_ptr != nullptr) {
// Reuse stored values
for (size_t k = 0; k < dim; k++) {
y_ptr[k] = residual_out_ptr[k] * w[k] * rms;
}
} else {
// Recompute sum
for (size_t k = 0; k < dim; k++) {
T sum_val = a_ptr[k] + b_ptr[k];
y_ptr[k] = sum_val * w[k] * rms;
}
}
}
......@@ -60,7 +73,7 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
}
template <typename T, typename Tw>
infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const Tw *w) {
infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const Tw *w, T *residual_out) {
static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value,
"T must be fp16_t or bf16_t");
......@@ -77,11 +90,16 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
T *residual_out_ptr = info->has_residual_out ?
(residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1]) : nullptr;
// Compute sum of squares for RMS normalization
// Compute sum of squares for RMS normalization and store add result
float sum_squared = 0.0f;
for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
if (residual_out_ptr != nullptr) {
residual_out_ptr[k] = utils::cast<T>(sum_val); // Store add result
}
sum_squared += sum_val * sum_val;
}
......@@ -89,17 +107,35 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
float rms = 1.f / std::sqrt(sum_squared / (float)(dim) + info->epsilon);
// Apply normalization: y = (a + b) * w * rms
for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
float val;
if constexpr (std::is_same<Tw, float>::value) {
val = sum_val * w[k] * rms;
} else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
val = sum_val * utils::cast<float>(w[k]) * rms;
} else {
std::abort();
// Reuse stored values if residual_out was computed, otherwise recompute
if (residual_out_ptr != nullptr) {
// Reuse stored values
for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(residual_out_ptr[k]);
float val;
if constexpr (std::is_same<Tw, float>::value) {
val = sum_val * w[k] * rms;
} else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
val = sum_val * utils::cast<float>(w[k]) * rms;
} else {
std::abort();
}
y_ptr[k] = utils::cast<T>(val);
}
} else {
// Recompute sum
for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
float val;
if constexpr (std::is_same<Tw, float>::value) {
val = sum_val * w[k] * rms;
} else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
val = sum_val * utils::cast<float>(w[k]) * rms;
} else {
std::abort();
}
y_ptr[k] = utils::cast<T>(val);
}
y_ptr[k] = utils::cast<T>(val);
}
}
......@@ -109,31 +145,31 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *a, const void *b, const void *weight,
void *stream) const {
void *residual_out, void *stream) const {
if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight));
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight, (fp16_t *)residual_out));
} else if (_info.wtype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight));
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight, (fp16_t *)residual_out));
} else if (_info.wtype == INFINI_DTYPE_BF16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight));
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight, (fp16_t *)residual_out));
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_info.atype == INFINI_DTYPE_BF16) {
if (_info.wtype == INFINI_DTYPE_BF16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight));
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight, (bf16_t *)residual_out));
} else if (_info.wtype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight));
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight, (bf16_t *)residual_out));
} else if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight));
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight, (bf16_t *)residual_out));
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_info.atype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (const float *)a, (const float *)b, (const float *)weight));
CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (const float *)a, (const float *)b, (const float *)weight, (float *)residual_out));
} else if (_info.atype == INFINI_DTYPE_F64) {
CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (const double *)a, (const double *)b, (const double *)weight));
CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (const double *)a, (const double *)b, (const double *)weight, (double *)residual_out));
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -18,6 +18,8 @@ public:
std::vector<ptrdiff_t> y_strides;
std::vector<ptrdiff_t> a_strides;
std::vector<ptrdiff_t> b_strides;
std::vector<ptrdiff_t> residual_out_strides;
bool has_residual_out;
size_t ndim() const { return shape.size(); }
size_t dim() const { return shape[ndim() - 1]; }
......@@ -27,7 +29,8 @@ public:
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon) {
float epsilon,
infiniopTensorDescriptor_t residual_out_desc) {
auto atype = y_desc->dtype();
auto wtype = weight_desc->dtype();
......@@ -95,6 +98,27 @@ public:
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
// Check residual_out_desc if provided
bool has_residual_out = (residual_out_desc != nullptr);
if (has_residual_out) {
const size_t residual_out_ndim = residual_out_desc->ndim();
if (residual_out_ndim != y_ndim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (residual_out_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Check shape matches
for (size_t i = 0; i < y_ndim; i++) {
if (residual_out_desc->dim(i) != y_desc->dim(i)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
if (residual_out_desc->stride(residual_out_ndim - 1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
}
AddRMSNormInfo info;
info.wtype = wtype;
info.atype = atype;
......@@ -103,6 +127,10 @@ public:
info.y_strides = y_desc->strides();
info.a_strides = a_desc->strides();
info.b_strides = b_desc->strides();
info.has_residual_out = has_residual_out;
if (has_residual_out) {
info.residual_out_strides = residual_out_desc->strides();
}
return utils::Result<AddRMSNormInfo>(info);
}
};
......
......@@ -37,7 +37,8 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon) {
float epsilon,
infiniopTensorDescriptor_t residual_out_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
......@@ -48,7 +49,8 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
a_desc, \
b_desc, \
weight_desc, \
epsilon)
epsilon, \
residual_out_desc)
switch (handle->device) {
#ifdef ENABLE_CPU_API
......@@ -118,12 +120,13 @@ __C infiniStatus_t infiniopAddRMSNorm(
const void *a,
const void *b,
const void *weight,
void *residual_out,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, y, a, b, weight, stream)
->calculate(workspace, workspace_size, y, a, b, weight, residual_out, stream)
switch (desc->device_type) {
......
......@@ -86,63 +86,35 @@ def parse_test_cases():
) # Weight is always contiguous
y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype)
# Test Case 1: Out-of-place (return value)
# Test Case 1: Out-of-place (return value) - returns (normalized_result, add_result)
residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
test_cases.append(
TestCase(
inputs=[a_spec, b_spec, w_spec],
kwargs={"epsilon": _EPSILON},
output_spec=None,
output_specs=[y_spec, residual_out_spec], # Two outputs
comparison_target=None,
tolerance=tolerance,
output_count=2, # Two outputs: normalized_result and add_result
description=f"AddRMSNorm - OUT_OF_PLACE",
)
)
# Test Case 2: In-place with explicit output tensor (add_rms_norm(a, b, w, out=y))
# Test Case 2: In-place with explicit output tensors (add_rms_norm_(y, residual_out, a, b, w))
if y_supports_inplace:
residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
test_cases.append(
TestCase(
inputs=[a_spec, b_spec, w_spec],
kwargs={"epsilon": _EPSILON},
output_spec=y_spec, # Specify the output tensor spec
kwargs={"epsilon": _EPSILON, "out": (y_spec, residual_out_spec)},
output_specs=[y_spec, residual_out_spec], # Two outputs
comparison_target="out",
tolerance=tolerance,
output_count=2,
description=f"AddRMSNorm - INPLACE(out)",
)
)
# Test Case 3: In-place on first input (add_rms_norm(a, b, w, out=a))
if a_supports_inplace:
test_cases.append(
TestCase(
inputs=[a_spec, b_spec, w_spec],
kwargs={
"out": 0,
"epsilon": _EPSILON,
}, # Use index 0 for first input
output_spec=None,
comparison_target=0, # Compare first input
tolerance=tolerance,
description=f"AddRMSNorm - INPLACE(a)",
)
)
# Test Case 4: In-place on second input (add_rms_norm(a, b, w, out=b))
if b_supports_inplace:
test_cases.append(
TestCase(
inputs=[a_spec, b_spec, w_spec],
kwargs={
"out": 1,
"epsilon": _EPSILON,
}, # Use index 1 for second input
output_spec=None,
comparison_target=1, # Compare second input
tolerance=tolerance,
description=f"AddRMSNorm - INPLACE(b)",
)
)
return test_cases
......@@ -156,7 +128,7 @@ class OpTest(BaseOperatorTest):
return parse_test_cases()
def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
"""PyTorch AddRMSNorm implementation"""
"""PyTorch AddRMSNorm implementation - returns (normalized_result, add_result)"""
input_dtype = a.dtype
# Compute add(a, b)
......@@ -165,18 +137,27 @@ class OpTest(BaseOperatorTest):
# Calculate RMSNorm: (a + b) * weight / sqrt(mean((a+b)^2) + epsilon)
variance = sum_tensor.pow(2).mean(-1, keepdim=True)
result = sum_tensor * torch.rsqrt(variance + epsilon) * weight_fp32
normalized_result = sum_tensor * torch.rsqrt(variance + epsilon) * weight_fp32
# Convert back to original dtype
result = result.to(input_dtype)
normalized_result = normalized_result.to(input_dtype)
add_result = sum_tensor.to(input_dtype)
if out is not None:
out.copy_(result)
return out
return result
# For in-place operations, we need to handle the output tuple
if isinstance(out, (tuple, list)) and len(out) == 2:
out[0].copy_(normalized_result)
out[1].copy_(add_result)
return tuple(out)
else:
# Single output - just return normalized result for backward compatibility
out.copy_(normalized_result)
return out
return (normalized_result, add_result)
def infinicore_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
    """InfiniCore AddRMSNorm implementation - returns (normalized_result, add_result)."""
    return infinicore.add_rms_norm(a, b, weight, epsilon, out=out)
......
......@@ -91,6 +91,7 @@ def test(
)
y = TestTensor(y_shape, y_stride, dtype, device, mode="ones")
residual_out = TestTensor(a_shape, a_stride, dtype, device, mode="ones")
a = TestTensor(a_shape, a_stride, dtype, device, scale=0.01)
b = TestTensor(b_shape, b_stride, dtype, device, scale=0.01)
w = TestTensor(w_shape, None, w_dtype, device)
......@@ -112,11 +113,12 @@ def test(
b.descriptor,
w.descriptor,
eps,
residual_out.descriptor,
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
for tensor in [a, b, y, w]:
for tensor in [a, b, y, w, residual_out]:
tensor.destroy_desc()
workspace_size = c_uint64(0)
......@@ -137,6 +139,7 @@ def test(
a.data(),
b.data(),
w.data(),
residual_out.data(),
None,
)
)
......@@ -144,9 +147,18 @@ def test(
lib_add_rms_norm()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
# Verify normalized result (y)
if DEBUG:
debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
# Verify add result (residual_out) - should be a + b
expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to(torch.float32)
expected_residual = expected_residual.to(a.torch_tensor().dtype)
if DEBUG:
debug(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)
assert torch.allclose(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)
# Profiling workflow
if PROFILE:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment