Commit 0c204dfd authored by PanZezhong's avatar PanZezhong Committed by wooway777
Browse files

issue/791 fix add_rmsnorm api and rmsnorm module

parent f9761a29
#pragma once #pragma once
#include "module.hpp"
#include "../ops.hpp" #include "../ops.hpp"
#include "module.hpp"
namespace infinicore::nn { namespace infinicore::nn {
...@@ -57,6 +57,21 @@ public: ...@@ -57,6 +57,21 @@ public:
*/ */
Tensor forward(const Tensor &x) const; Tensor forward(const Tensor &x) const;
/**
* @brief Forward pass: apply RMSNorm with a residual connection, updating both arguments.
*
* @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions.
* On return it holds the normalized output. When a residual is supplied the result
* is written into x in place; when no residual is supplied x is reassigned to a
* tensor returned by the out-of-place rms_norm op instead of being overwritten.
* @param residual Residual tensor to add to input before normalization.
* Will be modified in-place to the sum of input and residual.
* If it tests false (`!residual`, presumably an empty/unset tensor - confirm
* against the Tensor API), no addition is performed: residual is set to the
* original x and only the normalization is applied.
*
* The normalization is applied over the last dimension.
* For example:
* Input: [batch, seq_len, hidden_size] -> normalize over hidden_size
* Input: [batch, hidden_size] -> normalize over hidden_size
*
* @note The implementation calls op::add_rms_norm_inplace on CPU, NVIDIA, ILUVATAR, METAX
* and MOORE devices, and falls back to op::add_ followed by op::rms_norm_ on any
* other device type.
*/
void forward_inplace(Tensor &x, Tensor &residual) const;
// Module information // Module information
size_t normalized_shape() const { return normalized_shape_; } size_t normalized_shape() const { return normalized_shape_; }
double eps() const { return eps_; } double eps() const { return eps_; }
......
...@@ -5,16 +5,14 @@ ...@@ -5,16 +5,14 @@
#include <utility> #include <utility>
namespace infinicore::op { namespace infinicore::op {
class AddRMSNorm { INFINICORE_GRAPH_OP_CLASS(AddRMSNorm, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &, float);
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float);
static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
static common::OpDispatcher<schema> &dispatcher();
};
// Fused Add and RMS Normalization // Fused Add and RMS Normalization
// Returns: (normalized_result, add_result) // Returns: (normalized_result, add_result)
// The add_result can be used as residual for subsequent layers // The add_result can be used as residual for subsequent layers
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); std::pair<Tensor, Tensor> add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f);
void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f);
// Fused Add and RMS Normalization (inplace)
// normalized_result will be stored in input, add_result will be stored in residual
void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon = 1e-5f);
} // namespace infinicore::op } // namespace infinicore::op
...@@ -9,11 +9,11 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor( ...@@ -9,11 +9,11 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
infiniopAddRMSNormDescriptor_t *desc_ptr, infiniopAddRMSNormDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc, infiniopTensorDescriptor_t weight_desc,
float epsilon, float epsilon);
infiniopTensorDescriptor_t residual_out_desc);
__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size); __C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size);
...@@ -21,10 +21,10 @@ __C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t de ...@@ -21,10 +21,10 @@ __C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t de
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
void *y, void *y,
void *residual_out,
const void *a, const void *a,
const void *b, const void *b,
const void *weight, const void *weight,
void *residual_out,
void *stream); void *stream);
__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc); __C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc);
......
...@@ -43,7 +43,7 @@ from infinicore.dtype import ( ...@@ -43,7 +43,7 @@ from infinicore.dtype import (
uint8, uint8,
) )
from infinicore.ops.add import add from infinicore.ops.add import add
from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_ from infinicore.ops.add_rms_norm import add_rms_norm
from infinicore.ops.attention import attention from infinicore.ops.attention import attention
from infinicore.ops.matmul import matmul from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul from infinicore.ops.mul import mul
......
import infinicore.tensor as tensor
from infinicore.lib import _infinicore from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None): def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None, residual=None):
""" """
Fused Add and RMS Normalization. Fused Add and RMS Normalization.
...@@ -18,30 +18,17 @@ def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None): ...@@ -18,30 +18,17 @@ def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
The add_result can be used as residual for subsequent layers. The add_result can be used as residual for subsequent layers.
""" """
if out is None: if out is None:
result = _infinicore.add_rms_norm( out = tensor.empty(a.shape, dtype=a.dtype, device=a.device)
a._underlying, b._underlying, weight._underlying, epsilon if residual is None:
) residual = tensor.empty(b.shape, dtype=b.dtype, device=b.device)
return (Tensor(result[0]), Tensor(result[1]))
y, residual_out = out
_infinicore.add_rms_norm_( _infinicore.add_rms_norm_(
y._underlying, out._underlying,
residual_out._underlying, residual._underlying,
a._underlying, a._underlying,
b._underlying, b._underlying,
weight._underlying, weight._underlying,
epsilon, epsilon,
) )
return (y, residual_out)
def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5): return out, residual
"""In-place Fused Add and RMS Normalization."""
_infinicore.add_rms_norm_(
y._underlying,
residual_out._underlying,
a._underlying,
b._underlying,
weight._underlying,
epsilon,
)
...@@ -21,6 +21,24 @@ Tensor RMSNorm::forward(const Tensor &x) const { ...@@ -21,6 +21,24 @@ Tensor RMSNorm::forward(const Tensor &x) const {
return op::rms_norm(x, weight_, static_cast<float>(eps_)); return op::rms_norm(x, weight_, static_cast<float>(eps_));
} }
// Applies RMSNorm to `x`, adding `residual` to it first when one is supplied.
//
// On return `x` holds the normalized result and `residual` holds the value that
// was normalized (the original `x`, or `x + residual`).
void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const {
    const auto epsilon = static_cast<float>(eps_);

    // No residual supplied: the un-normalized input becomes the residual and
    // `x` is rebound to the tensor returned by the out-of-place op.
    if (!residual) {
        residual = x;
        x = op::rms_norm(x, weight_, epsilon);
        return;
    }

    // Device types routed to the fused add + RMSNorm op.
    // NOTE(review): presumably the backends that implement the fused kernel - confirm.
    const auto device_type = device_.getType();
    const bool use_fused_op = device_type == Device::Type::CPU
                           || device_type == Device::Type::NVIDIA
                           || device_type == Device::Type::ILUVATAR
                           || device_type == Device::Type::METAX
                           || device_type == Device::Type::MOORE;

    if (use_fused_op) {
        op::add_rms_norm_inplace(x, residual, weight_, epsilon);
        return;
    }

    // Fallback for every other device type: accumulate into `residual`, then
    // normalize that sum into `x`.
    op::add_(residual, x, residual);
    op::rms_norm_(x, residual, weight_, epsilon);
}
std::string RMSNorm::extra_repr() const { std::string RMSNorm::extra_repr() const {
return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")"; return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
} }
......
...@@ -4,26 +4,30 @@ ...@@ -4,26 +4,30 @@
namespace infinicore::op { namespace infinicore::op {
common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() { INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(AddRMSNorm);
static common::OpDispatcher<AddRMSNorm::schema> dispatcher_;
return dispatcher_;
};
void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { AddRMSNorm::AddRMSNorm(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight);
infinicore::context::setDevice(y->device()); INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, residual_out, a, b, weight, epsilon);
dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon);
} }
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) { void AddRMSNorm::execute(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(AddRMSNorm, y, residual_out, a, b, weight, epsilon);
}
std::pair<Tensor, Tensor> add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
auto y = Tensor::empty(a->shape(), a->dtype(), a->device()); auto y = Tensor::empty(a->shape(), a->dtype(), a->device());
auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device()); auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device());
add_rms_norm_(y, residual_out, a, b, weight, epsilon); add_rms_norm_(y, residual_out, a, b, weight, epsilon);
return std::make_pair(y, residual_out); return std::make_pair(y, residual_out);
} }
void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon); AddRMSNorm::execute(out, residual, a, b, weight, epsilon);
}
void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon) {
add_rms_norm_(input, residual, input, residual, weight, epsilon);
} }
} // namespace infinicore::op } // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/add_rms_norm.hpp" #include "infinicore/ops/add_rms_norm.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h> #include "../infiniop_impl.hpp"
namespace infinicore::op::add_rms_norm_impl::infiniop { namespace infinicore::op::add_rms_norm_impl::infiniop {
thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches( INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, AddRMSNorm, 100);
100, // capacity
[](infiniopAddRMSNormDescriptor_t &desc) { struct PlannedMeta {
if (desc != nullptr) { std::shared_ptr<Descriptor> descriptor;
INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc)); graph::GraphTensor workspace, out, residual, a, b, weight;
desc = nullptr; float epsilon;
} };
});
void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { void *plan(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon); size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon);
auto device = context::getDevice(); INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
auto &cache = caches.getCache(device); Descriptor, descriptor, AddRMSNorm,
seed, y->desc(), residual_out->desc(),
a->desc(), b->desc(), weight->desc(), epsilon);
INFINIOP_WORKSPACE_TENSOR(workspace, AddRMSNorm, descriptor);
auto desc_opt = cache.get(seed); auto planned = new PlannedMeta{
infiniopAddRMSNormDescriptor_t desc = nullptr; descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(y),
graph::GraphTensor(residual_out),
graph::GraphTensor(a),
graph::GraphTensor(b),
graph::GraphTensor(weight),
epsilon};
if (!desc_opt) { return planned;
INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor( }
context::getInfiniopHandle(device), &desc,
y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}
size_t workspace_size = 0; void run(void *planned_meta) {
INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size)); auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(infiniopAddRMSNorm( INFINICORE_CHECK_ERROR(infiniopAddRMSNorm(
desc, workspace->data(), workspace_size, planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(),
y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream())); planned->out->data(), planned->residual->data(), planned->a->data(), planned->b->data(), planned->weight->data(), context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
} }
static bool registered = []() { INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(AddRMSNorm, &plan, &run, &cleanup);
AddRMSNorm::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::add_rms_norm_impl::infiniop } // namespace infinicore::op::add_rms_norm_impl::infiniop
...@@ -33,19 +33,19 @@ ...@@ -33,19 +33,19 @@
infiniopHandle_t handle, \ infiniopHandle_t handle, \
Descriptor **desc_ptr, \ Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \ infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t residual_out_desc, \
infiniopTensorDescriptor_t a_desc, \ infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \ infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t weight_desc, \ infiniopTensorDescriptor_t weight_desc, \
float epsilon, \ float epsilon); \
infiniopTensorDescriptor_t residual_out_desc); \
\ \
infiniStatus_t calculate( \ infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \ void *workspace, size_t workspace_size, \
void *y, \ void *y, \
void *residual_out, \
const void *a, \ const void *a, \
const void *b, \ const void *b, \
const void *weight, \ const void *weight, \
void *residual_out, \
void *stream) const; \ void *stream) const; \
}; \ }; \
} }
......
...@@ -10,19 +10,19 @@ infiniStatus_t Descriptor::create( ...@@ -10,19 +10,19 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t handle, infiniopHandle_t handle,
Descriptor **desc_ptr, Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc, infiniopTensorDescriptor_t weight_desc,
float epsilon, float epsilon) {
infiniopTensorDescriptor_t residual_out_desc) { auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
CHECK_RESULT(result); CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id); *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
template <typename T> template <typename T>
infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const T *w, T *residual_out) { infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, T *residual_out, const T *a, const T *b, const T *w) {
const size_t batch_size = info->shape[0]; const size_t batch_size = info->shape[0];
const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1; const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
const size_t dim = info->dim(); const size_t dim = info->dim();
...@@ -61,7 +61,7 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T ...@@ -61,7 +61,7 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
} }
template <typename T, typename Tw> template <typename T, typename Tw>
infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const Tw *w, T *residual_out) { infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, T *residual_out, const T *a, const T *b, const Tw *w) {
static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value, static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value,
"T must be fp16_t or bf16_t"); "T must be fp16_t or bf16_t");
...@@ -112,32 +112,32 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const ...@@ -112,32 +112,32 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size, void *workspace, size_t workspace_size,
void *y, const void *a, const void *b, const void *weight, void *y, void *residual_out, const void *a, const void *b, const void *weight,
void *residual_out, void *stream) const { void *stream) const {
if (_info.atype == INFINI_DTYPE_F16) { if (_info.atype == INFINI_DTYPE_F16) {
if (_info.wtype == INFINI_DTYPE_F16) { if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight, (fp16_t *)residual_out)); CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight));
} else if (_info.wtype == INFINI_DTYPE_F32) { } else if (_info.wtype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight, (fp16_t *)residual_out)); CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight));
} else if (_info.wtype == INFINI_DTYPE_BF16) { } else if (_info.wtype == INFINI_DTYPE_BF16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight, (fp16_t *)residual_out)); CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight));
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
} else if (_info.atype == INFINI_DTYPE_BF16) { } else if (_info.atype == INFINI_DTYPE_BF16) {
if (_info.wtype == INFINI_DTYPE_BF16) { if (_info.wtype == INFINI_DTYPE_BF16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight, (bf16_t *)residual_out)); CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight));
} else if (_info.wtype == INFINI_DTYPE_F32) { } else if (_info.wtype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight, (bf16_t *)residual_out)); CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight));
} else if (_info.wtype == INFINI_DTYPE_F16) { } else if (_info.wtype == INFINI_DTYPE_F16) {
CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight, (bf16_t *)residual_out)); CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight));
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
} else if (_info.atype == INFINI_DTYPE_F32) { } else if (_info.atype == INFINI_DTYPE_F32) {
CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (const float *)a, (const float *)b, (const float *)weight, (float *)residual_out)); CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (float *)residual_out, (const float *)a, (const float *)b, (const float *)weight));
} else if (_info.atype == INFINI_DTYPE_F64) { } else if (_info.atype == INFINI_DTYPE_F64) {
CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (const double *)a, (const double *)b, (const double *)weight, (double *)residual_out)); CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (double *)residual_out, (const double *)a, (const double *)b, (const double *)weight));
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
......
...@@ -16,9 +16,9 @@ public: ...@@ -16,9 +16,9 @@ public:
float epsilon; float epsilon;
std::vector<size_t> shape; std::vector<size_t> shape;
std::vector<ptrdiff_t> y_strides; std::vector<ptrdiff_t> y_strides;
std::vector<ptrdiff_t> residual_out_strides;
std::vector<ptrdiff_t> a_strides; std::vector<ptrdiff_t> a_strides;
std::vector<ptrdiff_t> b_strides; std::vector<ptrdiff_t> b_strides;
std::vector<ptrdiff_t> residual_out_strides;
bool has_residual_out; bool has_residual_out;
size_t ndim() const { return shape.size(); } size_t ndim() const { return shape.size(); }
...@@ -26,11 +26,11 @@ public: ...@@ -26,11 +26,11 @@ public:
static utils::Result<AddRMSNormInfo> create( static utils::Result<AddRMSNormInfo> create(
infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc, infiniopTensorDescriptor_t weight_desc,
float epsilon, float epsilon) {
infiniopTensorDescriptor_t residual_out_desc) {
auto atype = y_desc->dtype(); auto atype = y_desc->dtype();
auto wtype = weight_desc->dtype(); auto wtype = weight_desc->dtype();
......
...@@ -49,12 +49,12 @@ infiniStatus_t Descriptor::create( ...@@ -49,12 +49,12 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t handle, infiniopHandle_t handle,
Descriptor **desc_ptr, Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc, infiniopTensorDescriptor_t weight_desc,
float epsilon, float epsilon) {
infiniopTensorDescriptor_t residual_out_desc) { auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
CHECK_RESULT(result); CHECK_RESULT(result);
auto info = result.take(); auto info = result.take();
...@@ -122,8 +122,8 @@ infiniStatus_t launchKernel( ...@@ -122,8 +122,8 @@ infiniStatus_t launchKernel(
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size, void *workspace, size_t workspace_size,
void *y, const void *a, const void *b, const void *weight, void *y, void *residual_out, const void *a, const void *b, const void *weight,
void *residual_out, void *stream) const { void *stream) const {
if (workspace_size < _workspace_size) { if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE; return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
......
...@@ -32,12 +32,12 @@ ...@@ -32,12 +32,12 @@
__C infiniStatus_t infiniopCreateAddRMSNormDescriptor( __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
infiniopAddRMSNormDescriptor_t *desc_ptr, infiniopAddRMSNormDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc, infiniopTensorDescriptor_t weight_desc,
float epsilon, float epsilon) {
infiniopTensorDescriptor_t residual_out_desc) {
#define CREATE(CASE, NAMESPACE) \ #define CREATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
...@@ -45,11 +45,11 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor( ...@@ -45,11 +45,11 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
handle, \ handle, \
reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor **>(desc_ptr), \ reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor **>(desc_ptr), \
y_desc, \ y_desc, \
residual_out_desc, \
a_desc, \ a_desc, \
b_desc, \ b_desc, \
weight_desc, \ weight_desc, \
epsilon, \ epsilon)
residual_out_desc)
switch (handle->device) { switch (handle->device) {
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
...@@ -116,16 +116,16 @@ __C infiniStatus_t infiniopAddRMSNorm( ...@@ -116,16 +116,16 @@ __C infiniStatus_t infiniopAddRMSNorm(
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
void *y, void *y,
void *residual_out,
const void *a, const void *a,
const void *b, const void *b,
const void *weight, const void *weight,
void *residual_out,
void *stream) { void *stream) {
#define CALCULATE(CASE, NAMESPACE) \ #define CALCULATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc) \ return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, y, a, b, weight, residual_out, stream) ->calculate(workspace, workspace_size, y, residual_out, a, b, weight, stream)
switch (desc->device_type) { switch (desc->device_type) {
......
...@@ -30,8 +30,24 @@ _TEST_CASES_DATA = [ ...@@ -30,8 +30,24 @@ _TEST_CASES_DATA = [
((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)), ((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None), ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None), ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)), (
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)), (4, 4, 2048),
(4, 4, 2048),
(4, 4, 2048),
(2048,),
(2048, 8192, 1),
(2048, 8192, 1),
(2048, 8192, 1),
),
(
(4, 4, 2048),
(4, 4, 2048),
(4, 4, 2048),
(2048,),
(16384, 4096, 1),
(16384, 4096, 1),
(16384, 4096, 1),
),
] ]
# Tolerance configuration # Tolerance configuration
...@@ -87,12 +103,14 @@ def parse_test_cases(): ...@@ -87,12 +103,14 @@ def parse_test_cases():
y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype) y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype)
# Test Case 1: Out-of-place (return value) - returns (normalized_result, add_result) # Test Case 1: Out-of-place (return value) - returns (normalized_result, add_result)
residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype) residual_out_spec = TensorSpec.from_tensor(
a_shape, a_strides, input_dtype
)
test_cases.append( test_cases.append(
TestCase( TestCase(
inputs=[a_spec, b_spec, w_spec], inputs=[a_spec, b_spec, w_spec],
kwargs={"epsilon": _EPSILON}, kwargs={"epsilon": _EPSILON},
output_specs=[y_spec, residual_out_spec], # Two outputs output_specs=None, # Two outputs
comparison_target=None, comparison_target=None,
tolerance=tolerance, tolerance=tolerance,
output_count=2, # Two outputs: normalized_result and add_result output_count=2, # Two outputs: normalized_result and add_result
...@@ -101,19 +119,25 @@ def parse_test_cases(): ...@@ -101,19 +119,25 @@ def parse_test_cases():
) )
# Test Case 2: In-place with explicit output tensors (add_rms_norm_(y, residual_out, a, b, w)) # Test Case 2: In-place with explicit output tensors (add_rms_norm_(y, residual_out, a, b, w))
if y_supports_inplace: # if y_supports_inplace:
residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype) # residual_out_spec = TensorSpec.from_tensor(
test_cases.append( # a_shape, a_strides, input_dtype
TestCase( # )
inputs=[a_spec, b_spec, w_spec], # test_cases.append(
kwargs={"epsilon": _EPSILON, "out": (y_spec, residual_out_spec)}, # TestCase(
output_specs=[y_spec, residual_out_spec], # Two outputs # inputs=[a_spec, b_spec, w_spec],
comparison_target="out", # kwargs={
tolerance=tolerance, # "epsilon": _EPSILON,
output_count=2, # "out": y_spec,
description=f"AddRMSNorm - INPLACE(out)", # "residual": residual_out_spec,
) # },
) # output_specs=[y_spec, residual_out_spec], # Two outputs
# comparison_target="out",
# tolerance=tolerance,
# output_count=2,
# description=f"AddRMSNorm - INPLACE(out)",
# )
# )
return test_cases return test_cases
...@@ -127,7 +151,9 @@ class OpTest(BaseOperatorTest): ...@@ -127,7 +151,9 @@ class OpTest(BaseOperatorTest):
def get_test_cases(self): def get_test_cases(self):
return parse_test_cases() return parse_test_cases()
def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs): def torch_operator(
self, a, b, weight, epsilon=_EPSILON, out=None, residual=None, **kwargs
):
"""PyTorch AddRMSNorm implementation - returns (normalized_result, add_result)""" """PyTorch AddRMSNorm implementation - returns (normalized_result, add_result)"""
input_dtype = a.dtype input_dtype = a.dtype
...@@ -144,21 +170,19 @@ class OpTest(BaseOperatorTest): ...@@ -144,21 +170,19 @@ class OpTest(BaseOperatorTest):
add_result = sum_tensor.to(input_dtype) add_result = sum_tensor.to(input_dtype)
if out is not None: if out is not None:
# For in-place operations, we need to handle the output tuple
if isinstance(out, (tuple, list)) and len(out) == 2:
out[0].copy_(normalized_result)
out[1].copy_(add_result)
return tuple(out)
else:
# Single output - just return normalized result for backward compatibility
out.copy_(normalized_result) out.copy_(normalized_result)
return out if residual is not None:
residual.copy_(add_result)
return (normalized_result, add_result) return (normalized_result, add_result)
def infinicore_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs): def infinicore_operator(
self, a, b, weight, epsilon=_EPSILON, out=None, residual=None, **kwargs
):
"""InfiniCore AddRMSNorm implementation - returns (normalized_result, add_result)""" """InfiniCore AddRMSNorm implementation - returns (normalized_result, add_result)"""
return infinicore.add_rms_norm(a, b, weight, epsilon, out=out) return infinicore.add_rms_norm(
a, b, weight, epsilon, out=out, residual=residual
)
def main(): def main():
......
...@@ -32,8 +32,24 @@ _TEST_CASES_ = [ ...@@ -32,8 +32,24 @@ _TEST_CASES_ = [
((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)), ((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None), ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None), ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)), (
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)), (4, 4, 2048),
(4, 4, 2048),
(4, 4, 2048),
(2048,),
(2048, 8192, 1),
(2048, 8192, 1),
(2048, 8192, 1),
),
(
(4, 4, 2048),
(4, 4, 2048),
(4, 4, 2048),
(2048,),
(16384, 4096, 1),
(16384, 4096, 1),
(16384, 4096, 1),
),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None), ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((15, 8192), (15, 8192), (15, 8192), (8192,), None, None, None), ((15, 8192), (15, 8192), (15, 8192), (8192,), None, None, None),
] ]
...@@ -97,7 +113,9 @@ def test( ...@@ -97,7 +113,9 @@ def test(
w = TestTensor(w_shape, None, w_dtype, device) w = TestTensor(w_shape, None, w_dtype, device)
eps = 1e-6 eps = 1e-6
add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps) add_rms_norm(
y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps
)
if sync is not None: if sync is not None:
sync() sync()
...@@ -109,11 +127,11 @@ def test( ...@@ -109,11 +127,11 @@ def test(
handle, handle,
ctypes.byref(descriptor), ctypes.byref(descriptor),
y.descriptor, y.descriptor,
residual_out.descriptor,
a.descriptor, a.descriptor,
b.descriptor, b.descriptor,
w.descriptor, w.descriptor,
eps, eps,
residual_out.descriptor,
) )
) )
...@@ -136,10 +154,10 @@ def test( ...@@ -136,10 +154,10 @@ def test(
workspace.data(), workspace.data(),
workspace_size.value, workspace_size.value,
y.data(), y.data(),
residual_out.data(),
a.data(), a.data(),
b.data(), b.data(),
w.data(), w.data(),
residual_out.data(),
None, None,
) )
) )
...@@ -154,11 +172,15 @@ def test( ...@@ -154,11 +172,15 @@ def test(
assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
# Verify add result (residual_out) - should be a + b # Verify add result (residual_out) - should be a + b
expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to(torch.float32) expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to(
torch.float32
)
expected_residual = expected_residual.to(a.torch_tensor().dtype) expected_residual = expected_residual.to(a.torch_tensor().dtype)
if DEBUG: if DEBUG:
debug(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol) debug(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)
assert torch.allclose(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol) assert torch.allclose(
residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol
)
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
......
...@@ -393,6 +393,7 @@ def add_rms_norm_(lib): ...@@ -393,6 +393,7 @@ def add_rms_norm_(lib):
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
c_float, c_float,
] ]
...@@ -412,6 +413,7 @@ def add_rms_norm_(lib): ...@@ -412,6 +413,7 @@ def add_rms_norm_(lib):
c_void_p, c_void_p,
c_void_p, c_void_p,
c_void_p, c_void_p,
c_void_p,
] ]
lib.infiniopDestroyAddRMSNormDescriptor.restype = c_int32 lib.infiniopDestroyAddRMSNormDescriptor.restype = c_int32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment