Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
0c204dfd
Commit
0c204dfd
authored
Jan 23, 2026
by
PanZezhong
Committed by
wooway777
Jan 27, 2026
Browse files
issue/791 fix add_rmsnorm api and rmsnorm module
parent
f9761a29
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
225 additions
and
152 deletions
+225
-152
include/infinicore/nn/rmsnorm.hpp
include/infinicore/nn/rmsnorm.hpp
+19
-4
include/infinicore/ops/add_rms_norm.hpp
include/infinicore/ops/add_rms_norm.hpp
+6
-8
include/infiniop/ops/add_rms_norm.h
include/infiniop/ops/add_rms_norm.h
+3
-3
python/infinicore/__init__.py
python/infinicore/__init__.py
+1
-1
python/infinicore/ops/add_rms_norm.py
python/infinicore/ops/add_rms_norm.py
+8
-21
src/infinicore/nn/rmsnorm.cc
src/infinicore/nn/rmsnorm.cc
+18
-0
src/infinicore/ops/add_rms_norm/add_rms_norm.cc
src/infinicore/ops/add_rms_norm/add_rms_norm.cc
+14
-10
src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc
src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc
+37
-34
src/infiniop/ops/add_rms_norm/add_rms_norm.h
src/infiniop/ops/add_rms_norm/add_rms_norm.h
+3
-3
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc
+15
-15
src/infiniop/ops/add_rms_norm/info.h
src/infiniop/ops/add_rms_norm/info.h
+3
-3
src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu
src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu
+5
-5
src/infiniop/ops/add_rms_norm/operator.cc
src/infiniop/ops/add_rms_norm/operator.cc
+6
-6
test/infinicore/ops/add_rms_norm.py
test/infinicore/ops/add_rms_norm.py
+54
-30
test/infiniop/add_rms_norm.py
test/infiniop/add_rms_norm.py
+31
-9
test/infiniop/libinfiniop/op_register.py
test/infiniop/libinfiniop/op_register.py
+2
-0
No files found.
include/infinicore/nn/rmsnorm.hpp
View file @
0c204dfd
#pragma once
#include "module.hpp"
#include "../ops.hpp"
#include "module.hpp"
namespace
infinicore
::
nn
{
...
...
@@ -57,6 +57,21 @@ public:
*/
Tensor
forward
(
const
Tensor
&
x
)
const
;
/**
* @brief Forward pass: apply RMSNorm in-place with residual
*
* @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions.
* Will be modified in-place to the normalized output.
* @param residual Residual tensor to add to input before normalization.
* Will be modified in-place to the sum of input and residual.
*
* The normalization is applied over the last dimension.
* For example:
* Input: [batch, seq_len, hidden_size] -> normalize over hidden_size
* Input: [batch, hidden_size] -> normalize over hidden_size
*/
void
forward_inplace
(
Tensor
&
x
,
Tensor
&
residual
)
const
;
// Module information
size_t
normalized_shape
()
const
{
return
normalized_shape_
;
}
double
eps
()
const
{
return
eps_
;
}
...
...
@@ -73,9 +88,9 @@ protected:
INFINICORE_NN_PARAMETER
(
weight
);
private:
size_t
normalized_shape_
;
// Size of the feature dimension
double
eps_
;
// Epsilon for numerical stability
DataType
dtype_
;
// Data type for weight
size_t
normalized_shape_
;
// Size of the feature dimension
double
eps_
;
// Epsilon for numerical stability
DataType
dtype_
;
// Data type for weight
};
}
// namespace infinicore::nn
include/infinicore/ops/add_rms_norm.hpp
View file @
0c204dfd
...
...
@@ -5,16 +5,14 @@
#include <utility>
namespace
infinicore
::
op
{
class
AddRMSNorm
{
public:
using
schema
=
void
(
*
)(
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
,
float
);
static
void
execute
(
Tensor
y
,
Tensor
residual_out
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
=
1e-5
f
);
static
common
::
OpDispatcher
<
schema
>
&
dispatcher
();
};
INFINICORE_GRAPH_OP_CLASS
(
AddRMSNorm
,
Tensor
,
Tensor
,
const
Tensor
&
,
const
Tensor
&
,
const
Tensor
&
,
float
);
// Fused Add and RMS Normalization
// Returns: (normalized_result, add_result)
// The add_result can be used as residual for subsequent layers
std
::
pair
<
Tensor
,
Tensor
>
add_rms_norm
(
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
=
1e-5
f
);
void
add_rms_norm_
(
Tensor
y
,
Tensor
residual_out
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
=
1e-5
f
);
std
::
pair
<
Tensor
,
Tensor
>
add_rms_norm
(
const
Tensor
&
a
,
const
Tensor
&
b
,
const
Tensor
&
weight
,
float
epsilon
=
1e-5
f
);
void
add_rms_norm_
(
Tensor
out
,
Tensor
residual
,
const
Tensor
&
a
,
const
Tensor
&
b
,
const
Tensor
&
weight
,
float
epsilon
=
1e-5
f
);
// Fused Add and RMS Normalization (inplace)
// normalized_result wil be stored in input, add_result will be stored in residual
void
add_rms_norm_inplace
(
Tensor
input
,
Tensor
residual
,
const
Tensor
&
weight
,
float
epsilon
=
1e-5
f
);
}
// namespace infinicore::op
include/infiniop/ops/add_rms_norm.h
View file @
0c204dfd
...
...
@@ -9,11 +9,11 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopHandle_t
handle
,
infiniopAddRMSNormDescriptor_t
*
desc_ptr
,
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
residual_out_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
,
infiniopTensorDescriptor_t
weight_desc
,
float
epsilon
,
infiniopTensorDescriptor_t
residual_out_desc
);
float
epsilon
);
__C
__export
infiniStatus_t
infiniopGetAddRMSNormWorkspaceSize
(
infiniopAddRMSNormDescriptor_t
desc
,
size_t
*
size
);
...
...
@@ -21,10 +21,10 @@ __C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t de
void
*
workspace
,
size_t
workspace_size
,
void
*
y
,
void
*
residual_out
,
const
void
*
a
,
const
void
*
b
,
const
void
*
weight
,
void
*
residual_out
,
void
*
stream
);
__C
__export
infiniStatus_t
infiniopDestroyAddRMSNormDescriptor
(
infiniopAddRMSNormDescriptor_t
desc
);
...
...
python/infinicore/__init__.py
View file @
0c204dfd
...
...
@@ -43,7 +43,7 @@ from infinicore.dtype import (
uint8
,
)
from
infinicore.ops.add
import
add
from
infinicore.ops.add_rms_norm
import
add_rms_norm
,
add_rms_norm_
from
infinicore.ops.add_rms_norm
import
add_rms_norm
from
infinicore.ops.attention
import
attention
from
infinicore.ops.matmul
import
matmul
from
infinicore.ops.mul
import
mul
...
...
python/infinicore/ops/add_rms_norm.py
View file @
0c204dfd
import
infinicore.tensor
as
tensor
from
infinicore.lib
import
_infinicore
from
infinicore.tensor
import
Tensor
def
add_rms_norm
(
a
,
b
,
weight
,
epsilon
=
1e-5
,
*
,
out
=
None
):
def
add_rms_norm
(
a
,
b
,
weight
,
epsilon
=
1e-5
,
*
,
out
=
None
,
residual
=
None
):
"""
Fused Add and RMS Normalization.
...
...
@@ -18,30 +18,17 @@ def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
The add_result can be used as residual for subsequent layers.
"""
if
out
is
None
:
result
=
_infinicore
.
add_rms_norm
(
a
.
_underlying
,
b
.
_underlying
,
weight
.
_underlying
,
epsilon
)
return
(
Tensor
(
result
[
0
]),
Tensor
(
result
[
1
]))
out
=
tensor
.
empty
(
a
.
shape
,
dtype
=
a
.
dtype
,
device
=
a
.
device
)
if
residual
is
None
:
residual
=
tensor
.
empty
(
b
.
shape
,
dtype
=
b
.
dtype
,
device
=
b
.
device
)
y
,
residual_out
=
out
_infinicore
.
add_rms_norm_
(
y
.
_underlying
,
residual
_out
.
_underlying
,
out
.
_underlying
,
residual
.
_underlying
,
a
.
_underlying
,
b
.
_underlying
,
weight
.
_underlying
,
epsilon
,
)
return
(
y
,
residual_out
)
def
add_rms_norm_
(
y
,
residual_out
,
a
,
b
,
weight
,
epsilon
=
1e-5
):
"""In-place Fused Add and RMS Normalization."""
_infinicore
.
add_rms_norm_
(
y
.
_underlying
,
residual_out
.
_underlying
,
a
.
_underlying
,
b
.
_underlying
,
weight
.
_underlying
,
epsilon
,
)
return
out
,
residual
src/infinicore/nn/rmsnorm.cc
View file @
0c204dfd
...
...
@@ -21,6 +21,24 @@ Tensor RMSNorm::forward(const Tensor &x) const {
return
op
::
rms_norm
(
x
,
weight_
,
static_cast
<
float
>
(
eps_
));
}
void
RMSNorm
::
forward_inplace
(
Tensor
&
x
,
Tensor
&
residual
)
const
{
if
(
!
residual
)
{
residual
=
x
;
x
=
op
::
rms_norm
(
x
,
weight_
,
static_cast
<
float
>
(
eps_
));
}
else
{
if
(
device_
.
getType
()
==
Device
::
Type
::
CPU
||
device_
.
getType
()
==
Device
::
Type
::
NVIDIA
||
device_
.
getType
()
==
Device
::
Type
::
ILUVATAR
||
device_
.
getType
()
==
Device
::
Type
::
METAX
||
device_
.
getType
()
==
Device
::
Type
::
MOORE
)
{
op
::
add_rms_norm_inplace
(
x
,
residual
,
weight_
,
static_cast
<
float
>
(
eps_
));
}
else
{
op
::
add_
(
residual
,
x
,
residual
);
op
::
rms_norm_
(
x
,
residual
,
weight_
,
static_cast
<
float
>
(
eps_
));
}
}
}
std
::
string
RMSNorm
::
extra_repr
()
const
{
return
"RMSNorm(normalized_shape="
+
std
::
to_string
(
normalized_shape_
)
+
", eps="
+
std
::
to_string
(
eps_
)
+
", dtype="
+
std
::
to_string
(
static_cast
<
int
>
(
dtype_
))
+
")"
;
}
...
...
src/infinicore/ops/add_rms_norm/add_rms_norm.cc
View file @
0c204dfd
...
...
@@ -4,26 +4,30 @@
namespace
infinicore
::
op
{
common
::
OpDispatcher
<
AddRMSNorm
::
schema
>
&
AddRMSNorm
::
dispatcher
()
{
static
common
::
OpDispatcher
<
AddRMSNorm
::
schema
>
dispatcher_
;
return
dispatcher_
;
};
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL
(
AddRMSNorm
);
void
AddRMSNorm
::
execute
(
Tensor
y
,
Tensor
residual_out
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
)
{
AddRMSNorm
::
AddRMSNorm
(
Tensor
y
,
Tensor
residual_out
,
const
Tensor
&
a
,
const
Tensor
&
b
,
const
Tensor
&
weight
,
float
epsilon
)
{
INFINICORE_ASSERT_TENSORS_SAME_DEVICE
(
y
,
residual_out
,
a
,
b
,
weight
);
infinicore
::
context
::
setDevice
(
y
->
device
());
dispatcher
().
lookup
(
y
->
device
().
getType
())(
y
,
residual_out
,
a
,
b
,
weight
,
epsilon
);
INFINICORE_GRAPH_OP_DISPATCH
(
y
->
device
().
getType
(),
y
,
residual_out
,
a
,
b
,
weight
,
epsilon
);
}
std
::
pair
<
Tensor
,
Tensor
>
add_rms_norm
(
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
)
{
void
AddRMSNorm
::
execute
(
Tensor
y
,
Tensor
residual_out
,
const
Tensor
&
a
,
const
Tensor
&
b
,
const
Tensor
&
weight
,
float
epsilon
)
{
INFINICORE_GRAPH_OP_RECORD_OR_RUN
(
AddRMSNorm
,
y
,
residual_out
,
a
,
b
,
weight
,
epsilon
);
}
std
::
pair
<
Tensor
,
Tensor
>
add_rms_norm
(
const
Tensor
&
a
,
const
Tensor
&
b
,
const
Tensor
&
weight
,
float
epsilon
)
{
auto
y
=
Tensor
::
empty
(
a
->
shape
(),
a
->
dtype
(),
a
->
device
());
auto
residual_out
=
Tensor
::
empty
(
a
->
shape
(),
a
->
dtype
(),
a
->
device
());
add_rms_norm_
(
y
,
residual_out
,
a
,
b
,
weight
,
epsilon
);
return
std
::
make_pair
(
y
,
residual_out
);
}
void
add_rms_norm_
(
Tensor
y
,
Tensor
residual_out
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
)
{
AddRMSNorm
::
execute
(
y
,
residual_out
,
a
,
b
,
weight
,
epsilon
);
void
add_rms_norm_
(
Tensor
out
,
Tensor
residual
,
const
Tensor
&
a
,
const
Tensor
&
b
,
const
Tensor
&
weight
,
float
epsilon
)
{
AddRMSNorm
::
execute
(
out
,
residual
,
a
,
b
,
weight
,
epsilon
);
}
void
add_rms_norm_inplace
(
Tensor
input
,
Tensor
residual
,
const
Tensor
&
weight
,
float
epsilon
)
{
add_rms_norm_
(
input
,
residual
,
input
,
residual
,
weight
,
epsilon
);
}
}
// namespace infinicore::op
src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc
View file @
0c204dfd
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/add_rms_norm.hpp"
#include "infinicore/ops/common/cache.hpp"
#include
<
infiniop
.h>
#include
"../
infiniop
_impl.hpp"
namespace
infinicore
::
op
::
add_rms_norm_impl
::
infiniop
{
thread_local
common
::
OpCache
<
size_t
,
infiniopAddRMSNormDescriptor_t
>
caches
(
100
,
// capacity
[](
infiniopAddRMSNormDescriptor_t
&
desc
)
{
if
(
desc
!=
nullptr
)
{
INFINICORE_CHECK_ERROR
(
infiniopDestroyAddRMSNormDescriptor
(
desc
));
desc
=
nullptr
;
}
});
INFINIOP_CACHABLE_DESCRIPTOR
(
Descriptor
,
AddRMSNorm
,
100
);
struct
PlannedMeta
{
std
::
shared_ptr
<
Descriptor
>
descriptor
;
graph
::
GraphTensor
workspace
,
out
,
residual
,
a
,
b
,
weight
;
float
epsilon
;
};
void
calculate
(
Tensor
y
,
Tensor
residual_out
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
)
{
void
*
plan
(
Tensor
y
,
Tensor
residual_out
,
const
Tensor
&
a
,
const
Tensor
&
b
,
const
Tensor
&
weight
,
float
epsilon
)
{
size_t
seed
=
hash_combine
(
y
,
residual_out
,
a
,
b
,
weight
,
epsilon
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE
(
Descriptor
,
descriptor
,
AddRMSNorm
,
seed
,
y
->
desc
(),
residual_out
->
desc
(),
a
->
desc
(),
b
->
desc
(),
weight
->
desc
(),
epsilon
);
INFINIOP_WORKSPACE_TENSOR
(
workspace
,
AddRMSNorm
,
descriptor
);
auto
desc_opt
=
cache
.
get
(
seed
);
infiniopAddRMSNormDescriptor_t
desc
=
nullptr
;
auto
planned
=
new
PlannedMeta
{
descriptor
,
graph
::
GraphTensor
(
workspace
),
graph
::
GraphTensor
(
y
),
graph
::
GraphTensor
(
residual_out
),
graph
::
GraphTensor
(
a
),
graph
::
GraphTensor
(
b
),
graph
::
GraphTensor
(
weight
),
epsilon
};
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreateAddRMSNormDescriptor
(
context
::
getInfiniopHandle
(
device
),
&
desc
,
y
->
desc
(),
a
->
desc
(),
b
->
desc
(),
weight
->
desc
(),
epsilon
,
residual_out
->
desc
()));
cache
.
put
(
seed
,
desc
);
}
else
{
desc
=
*
desc_opt
;
}
return
planned
;
}
size_t
workspace_size
=
0
;
INFINICORE_CHECK_ERROR
(
infiniopGetAddRMSNormWorkspaceSize
(
desc
,
&
workspace_size
));
std
::
shared_ptr
<
Memory
>
workspace
=
context
::
allocateMemory
(
workspace_size
);
void
run
(
void
*
planned_meta
)
{
auto
planned
=
reinterpret_cast
<
PlannedMeta
*>
(
planned_meta
);
INFINICORE_CHECK_ERROR
(
infiniopAddRMSNorm
(
desc
,
workspace
->
data
(),
workspace_size
,
y
->
data
(),
a
->
data
(),
b
->
data
(),
weight
->
data
(),
residual_out
->
data
(),
context
::
getStream
()));
planned
->
descriptor
->
desc
,
planned
->
workspace
->
data
(),
planned
->
workspace
->
numel
(),
planned
->
out
->
data
(),
planned
->
residual
->
data
(),
planned
->
a
->
data
(),
planned
->
b
->
data
(),
planned
->
weight
->
data
(),
context
::
getStream
()));
}
void
cleanup
(
void
**
planned_meta_ptr
)
{
delete
*
reinterpret_cast
<
PlannedMeta
**>
(
planned_meta_ptr
);
*
planned_meta_ptr
=
nullptr
;
}
static
bool
registered
=
[]()
{
AddRMSNorm
::
dispatcher
().
registerAll
(
&
calculate
,
false
);
return
true
;
}();
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE
(
AddRMSNorm
,
&
plan
,
&
run
,
&
cleanup
);
}
// namespace infinicore::op::add_rms_norm_impl::infiniop
src/infiniop/ops/add_rms_norm/add_rms_norm.h
View file @
0c204dfd
...
...
@@ -33,19 +33,19 @@
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t residual_out_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t weight_desc, \
float epsilon, \
infiniopTensorDescriptor_t residual_out_desc); \
float epsilon); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
void *residual_out, \
const void *a, \
const void *b, \
const void *weight, \
void *residual_out, \
void *stream) const; \
}; \
}
...
...
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc
View file @
0c204dfd
...
...
@@ -10,19 +10,19 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t
handle
,
Descriptor
**
desc_ptr
,
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
residual_out_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
,
infiniopTensorDescriptor_t
weight_desc
,
float
epsilon
,
infiniopTensorDescriptor_t
residual_out_desc
)
{
auto
result
=
AddRMSNormInfo
::
create
(
y_desc
,
a_desc
,
b_desc
,
weight_desc
,
epsilon
,
residual_out_desc
);
float
epsilon
)
{
auto
result
=
AddRMSNormInfo
::
create
(
y_desc
,
residual_out_desc
,
a_desc
,
b_desc
,
weight_desc
,
epsilon
);
CHECK_RESULT
(
result
);
*
desc_ptr
=
new
Descriptor
(
nullptr
,
result
.
take
(),
0
,
handle
->
device
,
handle
->
device_id
);
return
INFINI_STATUS_SUCCESS
;
}
template
<
typename
T
>
infiniStatus_t
add_rmsnorm
(
const
AddRMSNormInfo
*
info
,
T
*
y
,
const
T
*
a
,
const
T
*
b
,
const
T
*
w
,
T
*
residual_out
)
{
infiniStatus_t
add_rmsnorm
(
const
AddRMSNormInfo
*
info
,
T
*
y
,
T
*
residual_out
,
const
T
*
a
,
const
T
*
b
,
const
T
*
w
)
{
const
size_t
batch_size
=
info
->
shape
[
0
];
const
size_t
nhead
=
info
->
ndim
()
>
2
?
info
->
shape
[
1
]
:
1
;
const
size_t
dim
=
info
->
dim
();
...
...
@@ -61,7 +61,7 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
}
template
<
typename
T
,
typename
Tw
>
infiniStatus_t
add_rmsnormHalfPrecision
(
const
AddRMSNormInfo
*
info
,
T
*
y
,
const
T
*
a
,
const
T
*
b
,
const
Tw
*
w
,
T
*
residual_out
)
{
infiniStatus_t
add_rmsnormHalfPrecision
(
const
AddRMSNormInfo
*
info
,
T
*
y
,
T
*
residual_out
,
const
T
*
a
,
const
T
*
b
,
const
Tw
*
w
)
{
static_assert
(
std
::
is_same
<
T
,
fp16_t
>::
value
||
std
::
is_same
<
T
,
bf16_t
>::
value
,
"T must be fp16_t or bf16_t"
);
...
...
@@ -112,32 +112,32 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
void
*
y
,
const
void
*
a
,
const
void
*
b
,
const
void
*
weight
,
void
*
residual_out
,
void
*
stream
)
const
{
void
*
y
,
void
*
residual_out
,
const
void
*
a
,
const
void
*
b
,
const
void
*
weight
,
void
*
stream
)
const
{
if
(
_info
.
atype
==
INFINI_DTYPE_F16
)
{
if
(
_info
.
wtype
==
INFINI_DTYPE_F16
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
fp16_t
*
)
weight
,
(
fp16_t
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
fp16_t
*
)
residual_out
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
fp16_t
*
)
weight
));
}
else
if
(
_info
.
wtype
==
INFINI_DTYPE_F32
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
float
*
)
weight
,
(
fp16_t
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
fp16_t
*
)
residual_out
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
float
*
)
weight
));
}
else
if
(
_info
.
wtype
==
INFINI_DTYPE_BF16
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
bf16_t
*
)
weight
,
(
fp16_t
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
fp16_t
*
)
residual_out
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
bf16_t
*
)
weight
));
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_BF16
)
{
if
(
_info
.
wtype
==
INFINI_DTYPE_BF16
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
bf16_t
*
)
weight
,
(
bf16_t
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
bf16_t
*
)
residual_out
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
bf16_t
*
)
weight
));
}
else
if
(
_info
.
wtype
==
INFINI_DTYPE_F32
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
float
*
)
weight
,
(
bf16_t
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
bf16_t
*
)
residual_out
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
float
*
)
weight
));
}
else
if
(
_info
.
wtype
==
INFINI_DTYPE_F16
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
fp16_t
*
)
weight
,
(
bf16_t
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
bf16_t
*
)
residual_out
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
fp16_t
*
)
weight
));
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_F32
)
{
CHECK_STATUS
(
add_rmsnorm
(
&
_info
,
(
float
*
)
y
,
(
const
float
*
)
a
,
(
const
float
*
)
b
,
(
const
float
*
)
weight
,
(
float
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnorm
(
&
_info
,
(
float
*
)
y
,
(
float
*
)
residual_out
,
(
const
float
*
)
a
,
(
const
float
*
)
b
,
(
const
float
*
)
weight
));
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_F64
)
{
CHECK_STATUS
(
add_rmsnorm
(
&
_info
,
(
double
*
)
y
,
(
const
double
*
)
a
,
(
const
double
*
)
b
,
(
const
double
*
)
weight
,
(
double
*
)
residual_out
));
CHECK_STATUS
(
add_rmsnorm
(
&
_info
,
(
double
*
)
y
,
(
double
*
)
residual_out
,
(
const
double
*
)
a
,
(
const
double
*
)
b
,
(
const
double
*
)
weight
));
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
...
...
src/infiniop/ops/add_rms_norm/info.h
View file @
0c204dfd
...
...
@@ -16,9 +16,9 @@ public:
float
epsilon
;
std
::
vector
<
size_t
>
shape
;
std
::
vector
<
ptrdiff_t
>
y_strides
;
std
::
vector
<
ptrdiff_t
>
residual_out_strides
;
std
::
vector
<
ptrdiff_t
>
a_strides
;
std
::
vector
<
ptrdiff_t
>
b_strides
;
std
::
vector
<
ptrdiff_t
>
residual_out_strides
;
bool
has_residual_out
;
size_t
ndim
()
const
{
return
shape
.
size
();
}
...
...
@@ -26,11 +26,11 @@ public:
static
utils
::
Result
<
AddRMSNormInfo
>
create
(
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
residual_out_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
,
infiniopTensorDescriptor_t
weight_desc
,
float
epsilon
,
infiniopTensorDescriptor_t
residual_out_desc
)
{
float
epsilon
)
{
auto
atype
=
y_desc
->
dtype
();
auto
wtype
=
weight_desc
->
dtype
();
...
...
src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu
View file @
0c204dfd
...
...
@@ -49,12 +49,12 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t
handle
,
Descriptor
**
desc_ptr
,
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
residual_out_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
,
infiniopTensorDescriptor_t
weight_desc
,
float
epsilon
,
infiniopTensorDescriptor_t
residual_out_desc
)
{
auto
result
=
AddRMSNormInfo
::
create
(
y_desc
,
a_desc
,
b_desc
,
weight_desc
,
epsilon
,
residual_out_desc
);
float
epsilon
)
{
auto
result
=
AddRMSNormInfo
::
create
(
y_desc
,
residual_out_desc
,
a_desc
,
b_desc
,
weight_desc
,
epsilon
);
CHECK_RESULT
(
result
);
auto
info
=
result
.
take
();
...
...
@@ -122,8 +122,8 @@ infiniStatus_t launchKernel(
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
void
*
y
,
const
void
*
a
,
const
void
*
b
,
const
void
*
weight
,
void
*
residual_out
,
void
*
stream
)
const
{
void
*
y
,
void
*
residual_out
,
const
void
*
a
,
const
void
*
b
,
const
void
*
weight
,
void
*
stream
)
const
{
if
(
workspace_size
<
_workspace_size
)
{
return
INFINI_STATUS_INSUFFICIENT_WORKSPACE
;
...
...
src/infiniop/ops/add_rms_norm/operator.cc
View file @
0c204dfd
...
...
@@ -32,12 +32,12 @@
__C
infiniStatus_t
infiniopCreateAddRMSNormDescriptor
(
infiniopHandle_t
handle
,
infiniopAddRMSNormDescriptor_t
*
desc_ptr
,
infiniopTensorDescriptor_t
residual_out_desc
,
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
,
infiniopTensorDescriptor_t
weight_desc
,
float
epsilon
,
infiniopTensorDescriptor_t
residual_out_desc
)
{
float
epsilon
)
{
#define CREATE(CASE, NAMESPACE) \
case CASE: \
...
...
@@ -45,11 +45,11 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
handle, \
reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor **>(desc_ptr), \
y_desc, \
residual_out_desc, \
a_desc, \
b_desc, \
weight_desc, \
epsilon, \
residual_out_desc)
epsilon)
switch
(
handle
->
device
)
{
#ifdef ENABLE_CPU_API
...
...
@@ -116,16 +116,16 @@ __C infiniStatus_t infiniopAddRMSNorm(
void
*
workspace
,
size_t
workspace_size
,
void
*
y
,
void
*
residual_out
,
const
void
*
a
,
const
void
*
b
,
const
void
*
weight
,
void
*
residual_out
,
void
*
stream
)
{
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, y,
a, b, weight, residual_ou
t, stream)
->calculate(workspace, workspace_size, y,
residual_out, a, b, weigh
t, stream)
switch
(
desc
->
device_type
)
{
...
...
test/infinicore/ops/add_rms_norm.py
View file @
0c204dfd
...
...
@@ -30,8 +30,24 @@ _TEST_CASES_DATA = [
((
16
,
2048
),
(
16
,
2048
),
(
16
,
2048
),
(
2048
,),
(
4096
,
1
),
(
4096
,
1
),
(
4096
,
1
)),
((
15
,
3584
),
(
15
,
3584
),
(
15
,
3584
),
(
3584
,),
None
,
None
,
None
),
((
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
2048
,),
None
,
None
,
None
),
((
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
2048
,),
(
2048
,
8192
,
1
),
(
2048
,
8192
,
1
),
(
2048
,
8192
,
1
)),
((
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
2048
,),
(
16384
,
4096
,
1
),
(
16384
,
4096
,
1
),
(
16384
,
4096
,
1
)),
(
(
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
2048
,),
(
2048
,
8192
,
1
),
(
2048
,
8192
,
1
),
(
2048
,
8192
,
1
),
),
(
(
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
2048
,),
(
16384
,
4096
,
1
),
(
16384
,
4096
,
1
),
(
16384
,
4096
,
1
),
),
]
# Tolerance configuration
...
...
@@ -87,12 +103,14 @@ def parse_test_cases():
y_spec
=
TensorSpec
.
from_tensor
(
y_shape
,
y_strides
,
input_dtype
)
# Test Case 1: Out-of-place (return value) - returns (normalized_result, add_result)
residual_out_spec
=
TensorSpec
.
from_tensor
(
a_shape
,
a_strides
,
input_dtype
)
residual_out_spec
=
TensorSpec
.
from_tensor
(
a_shape
,
a_strides
,
input_dtype
)
test_cases
.
append
(
TestCase
(
inputs
=
[
a_spec
,
b_spec
,
w_spec
],
kwargs
=
{
"epsilon"
:
_EPSILON
},
output_specs
=
[
y_spec
,
residual_out_spec
]
,
# Two outputs
output_specs
=
None
,
# Two outputs
comparison_target
=
None
,
tolerance
=
tolerance
,
output_count
=
2
,
# Two outputs: normalized_result and add_result
...
...
@@ -101,19 +119,25 @@ def parse_test_cases():
)
# Test Case 2: In-place with explicit output tensors (add_rms_norm_(y, residual_out, a, b, w))
if
y_supports_inplace
:
residual_out_spec
=
TensorSpec
.
from_tensor
(
a_shape
,
a_strides
,
input_dtype
)
test_cases
.
append
(
TestCase
(
inputs
=
[
a_spec
,
b_spec
,
w_spec
],
kwargs
=
{
"epsilon"
:
_EPSILON
,
"out"
:
(
y_spec
,
residual_out_spec
)},
output_specs
=
[
y_spec
,
residual_out_spec
],
# Two outputs
comparison_target
=
"out"
,
tolerance
=
tolerance
,
output_count
=
2
,
description
=
f
"AddRMSNorm - INPLACE(out)"
,
)
)
# if y_supports_inplace:
# residual_out_spec = TensorSpec.from_tensor(
# a_shape, a_strides, input_dtype
# )
# test_cases.append(
# TestCase(
# inputs=[a_spec, b_spec, w_spec],
# kwargs={
# "epsilon": _EPSILON,
# "out": y_spec,
# "residual": residual_out_spec,
# },
# output_specs=[y_spec, residual_out_spec], # Two outputs
# comparison_target="out",
# tolerance=tolerance,
# output_count=2,
# description=f"AddRMSNorm - INPLACE(out)",
# )
# )
return
test_cases
...
...
@@ -127,7 +151,9 @@ class OpTest(BaseOperatorTest):
def
get_test_cases
(
self
):
return
parse_test_cases
()
def
torch_operator
(
self
,
a
,
b
,
weight
,
epsilon
=
_EPSILON
,
out
=
None
,
**
kwargs
):
def
torch_operator
(
self
,
a
,
b
,
weight
,
epsilon
=
_EPSILON
,
out
=
None
,
residual
=
None
,
**
kwargs
):
"""PyTorch AddRMSNorm implementation - returns (normalized_result, add_result)"""
input_dtype
=
a
.
dtype
...
...
@@ -144,21 +170,19 @@ class OpTest(BaseOperatorTest):
add_result
=
sum_tensor
.
to
(
input_dtype
)
if
out
is
not
None
:
# For in-place operations, we need to handle the output tuple
if
isinstance
(
out
,
(
tuple
,
list
))
and
len
(
out
)
==
2
:
out
[
0
].
copy_
(
normalized_result
)
out
[
1
].
copy_
(
add_result
)
return
tuple
(
out
)
else
:
# Single output - just return normalized result for backward compatibility
out
.
copy_
(
normalized_result
)
return
out
out
.
copy_
(
normalized_result
)
if
residual
is
not
None
:
residual
.
copy_
(
add_result
)
return
(
normalized_result
,
add_result
)
def
infinicore_operator
(
self
,
a
,
b
,
weight
,
epsilon
=
_EPSILON
,
out
=
None
,
**
kwargs
):
def
infinicore_operator
(
self
,
a
,
b
,
weight
,
epsilon
=
_EPSILON
,
out
=
None
,
residual
=
None
,
**
kwargs
):
"""InfiniCore AddRMSNorm implementation - returns (normalized_result, add_result)"""
return
infinicore
.
add_rms_norm
(
a
,
b
,
weight
,
epsilon
,
out
=
out
)
return
infinicore
.
add_rms_norm
(
a
,
b
,
weight
,
epsilon
,
out
=
out
,
residual
=
residual
)
def
main
():
...
...
test/infiniop/add_rms_norm.py
View file @
0c204dfd
...
...
@@ -32,8 +32,24 @@ _TEST_CASES_ = [
((
16
,
2048
),
(
16
,
2048
),
(
16
,
2048
),
(
2048
,),
(
4096
,
1
),
(
4096
,
1
),
(
4096
,
1
)),
((
15
,
3584
),
(
15
,
3584
),
(
15
,
3584
),
(
3584
,),
None
,
None
,
None
),
((
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
2048
,),
None
,
None
,
None
),
((
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
2048
,),
(
2048
,
8192
,
1
),
(
2048
,
8192
,
1
),
(
2048
,
8192
,
1
)),
((
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
2048
,),
(
16384
,
4096
,
1
),
(
16384
,
4096
,
1
),
(
16384
,
4096
,
1
)),
(
(
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
2048
,),
(
2048
,
8192
,
1
),
(
2048
,
8192
,
1
),
(
2048
,
8192
,
1
),
),
(
(
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
4
,
4
,
2048
),
(
2048
,),
(
16384
,
4096
,
1
),
(
16384
,
4096
,
1
),
(
16384
,
4096
,
1
),
),
((
15
,
3584
),
(
15
,
3584
),
(
15
,
3584
),
(
3584
,),
None
,
None
,
None
),
((
15
,
8192
),
(
15
,
8192
),
(
15
,
8192
),
(
8192
,),
None
,
None
,
None
),
]
...
...
@@ -97,7 +113,9 @@ def test(
w
=
TestTensor
(
w_shape
,
None
,
w_dtype
,
device
)
eps
=
1e-6
add_rms_norm
(
y
.
torch_tensor
(),
a
.
torch_tensor
(),
b
.
torch_tensor
(),
w
.
torch_tensor
(),
eps
)
add_rms_norm
(
y
.
torch_tensor
(),
a
.
torch_tensor
(),
b
.
torch_tensor
(),
w
.
torch_tensor
(),
eps
)
if
sync
is
not
None
:
sync
()
...
...
@@ -109,11 +127,11 @@ def test(
handle
,
ctypes
.
byref
(
descriptor
),
y
.
descriptor
,
residual_out
.
descriptor
,
a
.
descriptor
,
b
.
descriptor
,
w
.
descriptor
,
eps
,
residual_out
.
descriptor
,
)
)
...
...
@@ -136,10 +154,10 @@ def test(
workspace
.
data
(),
workspace_size
.
value
,
y
.
data
(),
residual_out
.
data
(),
a
.
data
(),
b
.
data
(),
w
.
data
(),
residual_out
.
data
(),
None
,
)
)
...
...
@@ -147,18 +165,22 @@ def test(
lib_add_rms_norm
()
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
# Verify normalized result (y)
if
DEBUG
:
debug
(
y
.
actual_tensor
(),
y
.
torch_tensor
(),
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
y
.
actual_tensor
(),
y
.
torch_tensor
(),
atol
=
atol
,
rtol
=
rtol
)
# Verify add result (residual_out) - should be a + b
expected_residual
=
a
.
torch_tensor
().
to
(
torch
.
float32
)
+
b
.
torch_tensor
().
to
(
torch
.
float32
)
expected_residual
=
a
.
torch_tensor
().
to
(
torch
.
float32
)
+
b
.
torch_tensor
().
to
(
torch
.
float32
)
expected_residual
=
expected_residual
.
to
(
a
.
torch_tensor
().
dtype
)
if
DEBUG
:
debug
(
residual_out
.
actual_tensor
(),
expected_residual
,
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
residual_out
.
actual_tensor
(),
expected_residual
,
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
residual_out
.
actual_tensor
(),
expected_residual
,
atol
=
atol
,
rtol
=
rtol
)
# Profiling workflow
if
PROFILE
:
...
...
test/infiniop/libinfiniop/op_register.py
View file @
0c204dfd
...
...
@@ -393,6 +393,7 @@ def add_rms_norm_(lib):
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
c_float
,
]
...
...
@@ -412,6 +413,7 @@ def add_rms_norm_(lib):
c_void_p
,
c_void_p
,
c_void_p
,
c_void_p
,
]
lib
.
infiniopDestroyAddRMSNormDescriptor
.
restype
=
c_int32
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment