Commit 0c204dfd authored by PanZezhong, committed by wooway777

issue/791 fix add_rmsnorm api and rmsnorm module

parent f9761a29
#pragma once
#include "module.hpp"
#include "../ops.hpp"
#include "module.hpp"
namespace infinicore::nn {
@@ -57,6 +57,21 @@ public:
*/
Tensor forward(const Tensor &x) const;
/**
* @brief Forward pass: apply RMSNorm in-place with residual
*
* @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions.
* Will be modified in-place to the normalized output.
* @param residual Residual tensor to add to input before normalization.
* Will be modified in-place to the sum of input and residual.
*
* The normalization is applied over the last dimension.
* For example:
* Input: [batch, seq_len, hidden_size] -> normalize over hidden_size
* Input: [batch, hidden_size] -> normalize over hidden_size
*/
void forward_inplace(Tensor &x, Tensor &residual) const;
// Module information
size_t normalized_shape() const { return normalized_shape_; }
double eps() const { return eps_; }
@@ -73,9 +88,9 @@ protected:
INFINICORE_NN_PARAMETER(weight);
private:
size_t normalized_shape_; // Size of the feature dimension
double eps_;              // Epsilon for numerical stability
DataType dtype_;          // Data type for weight
};
} // namespace infinicore::nn
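The forward/forward_inplace pair gives the module two call patterns: forward allocates a fresh output, while forward_inplace fuses the residual add and overwrites both arguments. A minimal PyTorch sketch of the intended numerics (reference only; rmsnorm_ref and forward_inplace_ref are illustrative names, not part of this API):

import torch

def rmsnorm_ref(x, weight, eps=1e-5):
    # accumulate in fp32, scale by the reciprocal RMS over the last dimension
    s = x.to(torch.float32)
    inv_rms = torch.rsqrt(s.pow(2).mean(-1, keepdim=True) + eps)
    return (s * inv_rms * weight.to(torch.float32)).to(x.dtype)

def forward_inplace_ref(x, residual, weight, eps=1e-5):
    residual.add_(x)                             # residual <- x + residual (kept for the next layer)
    x.copy_(rmsnorm_ref(residual, weight, eps))  # x <- RMSNorm(x + residual)
    return x, residual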
@@ -5,16 +5,14 @@
#include <utility>
namespace infinicore::op {
-class AddRMSNorm {
-public:
-    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float);
-    static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
-    static common::OpDispatcher<schema> &dispatcher();
-};
+INFINICORE_GRAPH_OP_CLASS(AddRMSNorm, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &, float);
// Fused Add and RMS Normalization
// Returns: (normalized_result, add_result)
// The add_result can be used as residual for subsequent layers
-std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
-void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
+std::pair<Tensor, Tensor> add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f);
+void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f);
// Fused Add and RMS Normalization (inplace)
// normalized_result will be stored in input, add_result will be stored in residual
void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon = 1e-5f);
} // namespace infinicore::op
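The out-of-place form returns both tensors because the un-normalized sum is exactly the residual stream the next layer needs. A hedged torch sketch of the declared contract (illustrative only):

import torch

def add_rms_norm_ref(a, b, weight, epsilon=1e-5):
    add_result = a + b                           # second return value, next layer's residual
    s = add_result.to(torch.float32)
    y = (s * torch.rsqrt(s.pow(2).mean(-1, keepdim=True) + epsilon)
         * weight.to(torch.float32)).to(a.dtype)
    return y, add_result                         # (normalized_result, add_result)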
@@ -9,11 +9,11 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopHandle_t handle,
infiniopAddRMSNormDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc,
+infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
-float epsilon,
-infiniopTensorDescriptor_t residual_out_desc);
+float epsilon);
__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size);
@@ -21,10 +21,10 @@ __C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t de
void *workspace,
size_t workspace_size,
void *y,
+void *residual_out,
const void *a,
const void *b,
const void *weight,
-void *residual_out,
void *stream);
__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc);
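After this change every descriptor and data pointer for residual_out sits directly after y. A full round trip through the C API then looks like the following ctypes sketch (check, lib, and the tensor wrappers are assumptions borrowed from the test harness later in this commit, not part of infiniop itself):

import ctypes

desc = infiniopAddRMSNormDescriptor_t()
check(lib.infiniopCreateAddRMSNormDescriptor(
    handle, ctypes.byref(desc),
    y.descriptor, residual_out.descriptor,   # residual_out now follows y
    a.descriptor, b.descriptor, w.descriptor,
    ctypes.c_float(1e-6)))
size = ctypes.c_size_t()
check(lib.infiniopGetAddRMSNormWorkspaceSize(desc, ctypes.byref(size)))
check(lib.infiniopAddRMSNorm(
    desc, workspace.data(), size.value,
    y.data(), residual_out.data(),           # outputs first, same order as the header
    a.data(), b.data(), w.data(), None))     # None = default stream
check(lib.infiniopDestroyAddRMSNormDescriptor(desc))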
......
@@ -43,7 +43,7 @@ from infinicore.dtype import (
uint8,
)
from infinicore.ops.add import add
-from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_
+from infinicore.ops.add_rms_norm import add_rms_norm
from infinicore.ops.attention import attention
from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul
......
import infinicore.tensor as tensor
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
-def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
+def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None, residual=None):
"""
Fused Add and RMS Normalization.
@@ -18,30 +18,17 @@ def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
The add_result can be used as residual for subsequent layers.
"""
     if out is None:
-        result = _infinicore.add_rms_norm(
-            a._underlying, b._underlying, weight._underlying, epsilon
-        )
-        return (Tensor(result[0]), Tensor(result[1]))
+        out = tensor.empty(a.shape, dtype=a.dtype, device=a.device)
+    if residual is None:
+        residual = tensor.empty(b.shape, dtype=b.dtype, device=b.device)
-    y, residual_out = out
     _infinicore.add_rms_norm_(
-        y._underlying,
-        residual_out._underlying,
+        out._underlying,
+        residual._underlying,
         a._underlying,
         b._underlying,
         weight._underlying,
         epsilon,
     )
-    return (y, residual_out)
-
-
-def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5):
-    """In-place Fused Add and RMS Normalization."""
-    _infinicore.add_rms_norm_(
-        y._underlying,
-        residual_out._underlying,
-        a._underlying,
-        b._underlying,
-        weight._underlying,
-        epsilon,
-    )
+    return out, residual
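Callers can now either let the wrapper allocate or pass preallocated buffers through the new keyword-only arguments; both forms return the (normalized, sum) pair. Illustrative usage, assuming a, b, w are infinicore tensors of matching shape:

import infinicore

y, res = infinicore.add_rms_norm(a, b, w, epsilon=1e-5)   # allocates both outputs
y, res = infinicore.add_rms_norm(a, b, w, epsilon=1e-5,
                                 out=y, residual=res)     # reuses the buffers in place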
@@ -21,6 +21,24 @@ Tensor RMSNorm::forward(const Tensor &x) const {
return op::rms_norm(x, weight_, static_cast<float>(eps_));
}
void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const {
if (!residual) {
residual = x;
x = op::rms_norm(x, weight_, static_cast<float>(eps_));
} else {
if (device_.getType() == Device::Type::CPU
|| device_.getType() == Device::Type::NVIDIA
|| device_.getType() == Device::Type::ILUVATAR
|| device_.getType() == Device::Type::METAX
|| device_.getType() == Device::Type::MOORE) {
op::add_rms_norm_inplace(x, residual, weight_, static_cast<float>(eps_));
} else {
op::add_(residual, x, residual);
op::rms_norm_(x, residual, weight_, static_cast<float>(eps_));
}
}
}
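The device gate mirrors which backends currently ship a fused AddRMSNorm kernel; every other device falls back to an explicit add_ followed by rms_norm_ of the sum, which computes the same values at the cost of an extra pass over memory.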
std::string RMSNorm::extra_repr() const {
return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
}
......
@@ -4,26 +4,30 @@
namespace infinicore::op {
-common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() {
-    static common::OpDispatcher<AddRMSNorm::schema> dispatcher_;
-    return dispatcher_;
-};
+INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(AddRMSNorm);

-void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
+AddRMSNorm::AddRMSNorm(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
     INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight);
-    infinicore::context::setDevice(y->device());
-    dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon);
+    INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, residual_out, a, b, weight, epsilon);
 }

+void AddRMSNorm::execute(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
+    INFINICORE_GRAPH_OP_RECORD_OR_RUN(AddRMSNorm, y, residual_out, a, b, weight, epsilon);
+}
+
-std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) {
+std::pair<Tensor, Tensor> add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
     auto y = Tensor::empty(a->shape(), a->dtype(), a->device());
     auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device());
     add_rms_norm_(y, residual_out, a, b, weight, epsilon);
     return std::make_pair(y, residual_out);
 }

-void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
-    AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon);
+void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
+    AddRMSNorm::execute(out, residual, a, b, weight, epsilon);
 }
+
+void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon) {
+    add_rms_norm_(input, residual, input, residual, weight, epsilon);
+}
} // namespace infinicore::op
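Note that add_rms_norm_inplace is a pure aliasing wrapper: it forwards input and residual as both the inputs and the outputs of add_rms_norm_, so backend kernels are expected to tolerate y aliasing a and residual_out aliasing b.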
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/add_rms_norm.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>
#include "../infiniop_impl.hpp"
namespace infinicore::op::add_rms_norm_impl::infiniop {
thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches(
100, // capacity
[](infiniopAddRMSNormDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc));
desc = nullptr;
}
});
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, AddRMSNorm, 100);
struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, out, residual, a, b, weight;
float epsilon;
};
void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
void *plan(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, AddRMSNorm,
seed, y->desc(), residual_out->desc(),
a->desc(), b->desc(), weight->desc(), epsilon);
INFINIOP_WORKSPACE_TENSOR(workspace, AddRMSNorm, descriptor);
auto desc_opt = cache.get(seed);
infiniopAddRMSNormDescriptor_t desc = nullptr;
auto planned = new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(y),
graph::GraphTensor(residual_out),
graph::GraphTensor(a),
graph::GraphTensor(b),
graph::GraphTensor(weight),
epsilon};
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor(
context::getInfiniopHandle(device), &desc,
y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}
return planned;
}
size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(infiniopAddRMSNorm(
desc, workspace->data(), workspace_size,
y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream()));
planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(),
planned->out->data(), planned->residual->data(), planned->a->data(), planned->b->data(), planned->weight->data(), context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
static bool registered = []() {
AddRMSNorm::dispatcher().registerAll(&calculate, false);
return true;
}();
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(AddRMSNorm, &plan, &run, &cleanup);
} // namespace infinicore::op::add_rms_norm_impl::infiniop
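The rewritten backend follows the graph-op protocol: plan is called once at capture time to hash the arguments, fetch or create a cached descriptor, and pin the tensors; run replays the launch from that metadata; cleanup frees it at teardown. A hedged Python analogy of the executor's side of the contract (replay_node is hypothetical, not repo code):

def replay_node(plan, run, cleanup, args, replays=1):
    meta = plan(*args)        # capture: descriptor lookup + workspace binding, done once
    for _ in range(replays):
        run(meta)             # replay: launch infiniopAddRMSNorm from cached state
    cleanup(meta)             # teardown: delete the PlannedMeta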
@@ -33,19 +33,19 @@
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
+infiniopTensorDescriptor_t residual_out_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t weight_desc, \
-float epsilon, \
-infiniopTensorDescriptor_t residual_out_desc); \
+float epsilon); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
+void *residual_out, \
const void *a, \
const void *b, \
const void *weight, \
-void *residual_out, \
void *stream) const; \
}; \
}
......
@@ -10,19 +10,19 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
+infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
-float epsilon,
-infiniopTensorDescriptor_t residual_out_desc) {
-    auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
+float epsilon) {
+    auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename T>
-infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const T *w, T *residual_out) {
+infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, T *residual_out, const T *a, const T *b, const T *w) {
const size_t batch_size = info->shape[0];
const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
const size_t dim = info->dim();
@@ -61,7 +61,7 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
}
template <typename T, typename Tw>
-infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const Tw *w, T *residual_out) {
+infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, T *residual_out, const T *a, const T *b, const Tw *w) {
static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value,
"T must be fp16_t or bf16_t");
@@ -112,32 +112,32 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
-    void *y, const void *a, const void *b, const void *weight,
-    void *residual_out, void *stream) const {
+    void *y, void *residual_out, const void *a, const void *b, const void *weight,
+    void *stream) const {
if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) {
-            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight, (fp16_t *)residual_out));
+            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight));
         } else if (_info.wtype == INFINI_DTYPE_F32) {
-            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight, (fp16_t *)residual_out));
+            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight));
         } else if (_info.wtype == INFINI_DTYPE_BF16) {
-            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight, (fp16_t *)residual_out));
+            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight));
         } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
         }
     } else if (_info.atype == INFINI_DTYPE_BF16) {
         if (_info.wtype == INFINI_DTYPE_BF16) {
-            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight, (bf16_t *)residual_out));
+            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight));
         } else if (_info.wtype == INFINI_DTYPE_F32) {
-            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight, (bf16_t *)residual_out));
+            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight));
         } else if (_info.wtype == INFINI_DTYPE_F16) {
-            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight, (bf16_t *)residual_out));
+            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight));
         } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
         }
     } else if (_info.atype == INFINI_DTYPE_F32) {
-        CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (const float *)a, (const float *)b, (const float *)weight, (float *)residual_out));
+        CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (float *)residual_out, (const float *)a, (const float *)b, (const float *)weight));
     } else if (_info.atype == INFINI_DTYPE_F64) {
-        CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (const double *)a, (const double *)b, (const double *)weight, (double *)residual_out));
+        CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (double *)residual_out, (const double *)a, (const double *)b, (const double *)weight));
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
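The branches above enumerate the accepted activation/weight dtype pairs; anything else returns INFINI_STATUS_BAD_TENSOR_DTYPE. The same table, condensed (illustrative Python, not code from this repo):

SUPPORTED_WTYPES = {
    "f16":  {"f16", "f32", "bf16"},   # half activations may carry sibling or wider weights
    "bf16": {"bf16", "f32", "f16"},
    "f32":  {"f32"},
    "f64":  {"f64"},
}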
......
@@ -16,9 +16,9 @@ public:
float epsilon;
std::vector<size_t> shape;
std::vector<ptrdiff_t> y_strides;
+std::vector<ptrdiff_t> residual_out_strides;
std::vector<ptrdiff_t> a_strides;
std::vector<ptrdiff_t> b_strides;
-std::vector<ptrdiff_t> residual_out_strides;
bool has_residual_out;
size_t ndim() const { return shape.size(); }
@@ -26,11 +26,11 @@ public:
static utils::Result<AddRMSNormInfo> create(
infiniopTensorDescriptor_t y_desc,
+infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
-float epsilon,
-infiniopTensorDescriptor_t residual_out_desc) {
+float epsilon) {
auto atype = y_desc->dtype();
auto wtype = weight_desc->dtype();
......
@@ -49,12 +49,12 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
+infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
-float epsilon,
-infiniopTensorDescriptor_t residual_out_desc) {
-    auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
+float epsilon) {
+    auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
@@ -122,8 +122,8 @@ infiniStatus_t launchKernel(
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
-    void *y, const void *a, const void *b, const void *weight,
-    void *residual_out, void *stream) const {
+    void *y, void *residual_out, const void *a, const void *b, const void *weight,
+    void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
......
@@ -32,12 +32,12 @@
__C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopHandle_t handle,
infiniopAddRMSNormDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y_desc,
+   infiniopTensorDescriptor_t residual_out_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
-   float epsilon,
-   infiniopTensorDescriptor_t residual_out_desc) {
+   float epsilon) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
@@ -45,11 +45,11 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
handle, \
reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor **>(desc_ptr), \
y_desc, \
+residual_out_desc, \
a_desc, \
b_desc, \
weight_desc, \
-epsilon, \
-residual_out_desc)
+epsilon)
switch (handle->device) {
#ifdef ENABLE_CPU_API
@@ -116,16 +116,16 @@ __C infiniStatus_t infiniopAddRMSNorm(
void *workspace,
size_t workspace_size,
void *y,
+void *residual_out,
const void *a,
const void *b,
const void *weight,
-void *residual_out,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
 return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc) \
-    ->calculate(workspace, workspace_size, y, a, b, weight, residual_out, stream)
+    ->calculate(workspace, workspace_size, y, residual_out, a, b, weight, stream)
switch (desc->device_type) {
......
@@ -30,8 +30,24 @@ _TEST_CASES_DATA = [
((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
(
(4, 4, 2048),
(4, 4, 2048),
(4, 4, 2048),
(2048,),
(2048, 8192, 1),
(2048, 8192, 1),
(2048, 8192, 1),
),
(
(4, 4, 2048),
(4, 4, 2048),
(4, 4, 2048),
(2048,),
(16384, 4096, 1),
(16384, 4096, 1),
(16384, 4096, 1),
),
]
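Each tuple appears to pack the shapes of a, b, and y, then the weight shape, followed by the corresponding strides, with None meaning contiguous; the multi-line entries are the same cases as before, only reflowed.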
# Tolerance configuration
@@ -87,12 +103,14 @@ def parse_test_cases():
y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype)
# Test Case 1: Out-of-place (return value) - returns (normalized_result, add_result)
residual_out_spec = TensorSpec.from_tensor(
a_shape, a_strides, input_dtype
)
test_cases.append(
TestCase(
inputs=[a_spec, b_spec, w_spec],
kwargs={"epsilon": _EPSILON},
-                output_specs=[y_spec, residual_out_spec],  # Two outputs
+                output_specs=None,  # Two outputs
comparison_target=None,
tolerance=tolerance,
output_count=2, # Two outputs: normalized_result and add_result
@@ -101,19 +119,25 @@ def parse_test_cases():
)
# Test Case 2: In-place with explicit output tensors (add_rms_norm_(y, residual_out, a, b, w))
-        if y_supports_inplace:
-            residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
-            test_cases.append(
-                TestCase(
-                    inputs=[a_spec, b_spec, w_spec],
-                    kwargs={"epsilon": _EPSILON, "out": (y_spec, residual_out_spec)},
-                    output_specs=[y_spec, residual_out_spec],  # Two outputs
-                    comparison_target="out",
-                    tolerance=tolerance,
-                    output_count=2,
-                    description=f"AddRMSNorm - INPLACE(out)",
-                )
-            )
# if y_supports_inplace:
# residual_out_spec = TensorSpec.from_tensor(
# a_shape, a_strides, input_dtype
# )
# test_cases.append(
# TestCase(
# inputs=[a_spec, b_spec, w_spec],
# kwargs={
# "epsilon": _EPSILON,
# "out": y_spec,
# "residual": residual_out_spec,
# },
# output_specs=[y_spec, residual_out_spec], # Two outputs
# comparison_target="out",
# tolerance=tolerance,
# output_count=2,
# description=f"AddRMSNorm - INPLACE(out)",
# )
# )
return test_cases
@@ -127,7 +151,9 @@ class OpTest(BaseOperatorTest):
def get_test_cases(self):
return parse_test_cases()
-    def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
+    def torch_operator(
+        self, a, b, weight, epsilon=_EPSILON, out=None, residual=None, **kwargs
+    ):
"""PyTorch AddRMSNorm implementation - returns (normalized_result, add_result)"""
input_dtype = a.dtype
@@ -144,21 +170,19 @@
         add_result = sum_tensor.to(input_dtype)

         if out is not None:
-            # For in-place operations, we need to handle the output tuple
-            if isinstance(out, (tuple, list)) and len(out) == 2:
-                out[0].copy_(normalized_result)
-                out[1].copy_(add_result)
-                return tuple(out)
-            else:
-                # Single output - just return normalized result for backward compatibility
-                out.copy_(normalized_result)
-                return out
+            out.copy_(normalized_result)
+        if residual is not None:
+            residual.copy_(add_result)

         return (normalized_result, add_result)
-    def infinicore_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
+    def infinicore_operator(
+        self, a, b, weight, epsilon=_EPSILON, out=None, residual=None, **kwargs
+    ):
         """InfiniCore AddRMSNorm implementation - returns (normalized_result, add_result)"""
-        return infinicore.add_rms_norm(a, b, weight, epsilon, out=out)
+        return infinicore.add_rms_norm(
+            a, b, weight, epsilon, out=out, residual=residual
+        )
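With the tuple-valued out path gone, the torch reference now mirrors the binding exactly: out and residual are optional preallocated buffers that get filled when supplied, and the freshly computed pair is always returned.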
def main():
......
@@ -32,8 +32,24 @@ _TEST_CASES_ = [
((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
(
(4, 4, 2048),
(4, 4, 2048),
(4, 4, 2048),
(2048,),
(2048, 8192, 1),
(2048, 8192, 1),
(2048, 8192, 1),
),
(
(4, 4, 2048),
(4, 4, 2048),
(4, 4, 2048),
(2048,),
(16384, 4096, 1),
(16384, 4096, 1),
(16384, 4096, 1),
),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((15, 8192), (15, 8192), (15, 8192), (8192,), None, None, None),
]
@@ -97,7 +113,9 @@ def test(
w = TestTensor(w_shape, None, w_dtype, device)
eps = 1e-6
add_rms_norm(
y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps
)
if sync is not None:
sync()
@@ -109,11 +127,11 @@
handle,
ctypes.byref(descriptor),
y.descriptor,
+residual_out.descriptor,
a.descriptor,
b.descriptor,
w.descriptor,
eps,
-residual_out.descriptor,
)
)
@@ -136,10 +154,10 @@
workspace.data(),
workspace_size.value,
y.data(),
+residual_out.data(),
a.data(),
b.data(),
w.data(),
-residual_out.data(),
None,
)
)
@@ -147,18 +165,22 @@
lib_add_rms_norm()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
# Verify normalized result (y)
if DEBUG:
debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
# Verify add result (residual_out) - should be a + b
expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to(
torch.float32
)
expected_residual = expected_residual.to(a.torch_tensor().dtype)
if DEBUG:
debug(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)
assert torch.allclose(
residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol
)
# Profiling workflow
if PROFILE:
......
@@ -393,6 +393,7 @@ def add_rms_norm_(lib):
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
c_float,
]
@@ -412,6 +413,7 @@ def add_rms_norm_(lib):
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyAddRMSNormDescriptor.restype = c_int32
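For orientation, the complete bindings after this commit plausibly read as follows (a hedged reconstruction from the C header above, reusing this file's existing ctypes aliases; not verbatim repo code):

lib.infiniopCreateAddRMSNormDescriptor.restype = c_int32
lib.infiniopCreateAddRMSNormDescriptor.argtypes = [
    infiniopHandle_t,
    POINTER(infiniopAddRMSNormDescriptor_t),
    infiniopTensorDescriptor_t,  # y
    infiniopTensorDescriptor_t,  # residual_out
    infiniopTensorDescriptor_t,  # a
    infiniopTensorDescriptor_t,  # b
    infiniopTensorDescriptor_t,  # weight
    c_float,
]

lib.infiniopAddRMSNorm.restype = c_int32
lib.infiniopAddRMSNorm.argtypes = [
    infiniopAddRMSNormDescriptor_t,
    c_void_p,   # workspace
    c_size_t,   # workspace_size
    c_void_p,   # y
    c_void_p,   # residual_out
    c_void_p,   # a
    c_void_p,   # b
    c_void_p,   # weight
    c_void_p,   # stream
]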
......