Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
2a432b34
Commit
2a432b34
authored
Dec 24, 2025
by
zhuyue
Committed by
gongchensu
Dec 24, 2025
Browse files
Unify add_rms_norm to always return (normalized_result, add_result) pair.
parent
7d60e5b8
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
196 additions
and
103 deletions
+196
-103
include/infinicore/ops/add_rms_norm.hpp
include/infinicore/ops/add_rms_norm.hpp
+8
-4
include/infiniop/ops/add_rms_norm.h
include/infiniop/ops/add_rms_norm.h
+3
-1
python/infinicore/__init__.py
python/infinicore/__init__.py
+2
-1
python/infinicore/ops/add_rms_norm.py
python/infinicore/ops/add_rms_norm.py
+23
-3
src/infinicore/ops/add_rms_norm/add_rms_norm.cc
src/infinicore/ops/add_rms_norm/add_rms_norm.cc
+9
-8
src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc
src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc
+4
-4
src/infinicore/pybind11/ops/add_rms_norm.hpp
src/infinicore/pybind11/ops/add_rms_norm.hpp
+5
-2
src/infiniop/ops/add_rms_norm/add_rms_norm.h
src/infiniop/ops/add_rms_norm/add_rms_norm.h
+3
-1
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc
+66
-30
src/infiniop/ops/add_rms_norm/info.h
src/infiniop/ops/add_rms_norm/info.h
+29
-1
src/infiniop/ops/add_rms_norm/operator.cc
src/infiniop/ops/add_rms_norm/operator.cc
+6
-3
test/infinicore/ops/add_rms_norm.py
test/infinicore/ops/add_rms_norm.py
+25
-44
test/infiniop/add_rms_norm.py
test/infiniop/add_rms_norm.py
+13
-1
No files found.
include/infinicore/ops/add_rms_norm.hpp
View file @
2a432b34
...
...
@@ -2,15 +2,19 @@
#include "../device.hpp"
#include "common/op.hpp"
#include <utility>
namespace
infinicore
::
op
{
class
AddRMSNorm
{
public:
using
schema
=
void
(
*
)(
Tensor
,
Tensor
,
Tensor
,
Tensor
,
float
);
static
void
execute
(
Tensor
y
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
=
1e-5
f
);
using
schema
=
void
(
*
)(
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
,
float
);
static
void
execute
(
Tensor
y
,
Tensor
residual_out
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
=
1e-5
f
);
static
common
::
OpDispatcher
<
schema
>
&
dispatcher
();
};
Tensor
add_rms_norm
(
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
=
1e-5
f
);
void
add_rms_norm_
(
Tensor
y
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
=
1e-5
f
);
// Fused Add and RMS Normalization
// Returns: (normalized_result, add_result)
// The add_result can be used as residual for subsequent layers
std
::
pair
<
Tensor
,
Tensor
>
add_rms_norm
(
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
=
1e-5
f
);
void
add_rms_norm_
(
Tensor
y
,
Tensor
residual_out
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
=
1e-5
f
);
}
// namespace infinicore::op
include/infiniop/ops/add_rms_norm.h
View file @
2a432b34
...
...
@@ -12,7 +12,8 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
,
infiniopTensorDescriptor_t
weight_desc
,
float
epsilon
);
float
epsilon
,
infiniopTensorDescriptor_t
residual_out_desc
);
__C
__export
infiniStatus_t
infiniopGetAddRMSNormWorkspaceSize
(
infiniopAddRMSNormDescriptor_t
desc
,
size_t
*
size
);
...
...
@@ -23,6 +24,7 @@ __C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t de
const
void
*
a
,
const
void
*
b
,
const
void
*
weight
,
void
*
residual_out
,
void
*
stream
);
__C
__export
infiniStatus_t
infiniopDestroyAddRMSNormDescriptor
(
infiniopAddRMSNormDescriptor_t
desc
);
...
...
python/infinicore/__init__.py
View file @
2a432b34
...
...
@@ -40,7 +40,7 @@ from infinicore.dtype import (
uint8
,
)
from
infinicore.ops.add
import
add
from
infinicore.ops.add_rms_norm
import
add_rms_norm
from
infinicore.ops.add_rms_norm
import
add_rms_norm
,
add_rms_norm_
from
infinicore.ops.attention
import
attention
from
infinicore.ops.matmul
import
matmul
from
infinicore.ops.mul
import
mul
...
...
@@ -104,6 +104,7 @@ __all__ = [
# Operations.
"add"
,
"add_rms_norm"
,
"add_rms_norm_"
,
"attention"
,
"matmul"
,
"mul"
,
...
...
python/infinicore/ops/add_rms_norm.py
View file @
2a432b34
...
...
@@ -3,9 +3,29 @@ from infinicore.tensor import Tensor
def
add_rms_norm
(
a
,
b
,
weight
,
epsilon
=
1e-5
,
*
,
out
=
None
):
"""
Fused Add and RMS Normalization.
Args:
a: First input tensor
b: Second input tensor
weight: Scale weights
epsilon: Small constant for numerical stability, default is 1e-5
out: Optional output tuple (y, residual_out) for in-place operation
Returns:
Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b)
The add_result can be used as residual for subsequent layers.
"""
if
out
is
None
:
return
Tensor
(
_infinicore
.
add_rms_norm
(
a
.
_underlying
,
b
.
_underlying
,
weight
.
_underlying
,
epsilon
))
result
=
_infinicore
.
add_rms_norm
(
a
.
_underlying
,
b
.
_underlying
,
weight
.
_underlying
,
epsilon
)
return
(
Tensor
(
result
[
0
]),
Tensor
(
result
[
1
]))
y
,
residual_out
=
out
_infinicore
.
add_rms_norm_
(
y
.
_underlying
,
residual_out
.
_underlying
,
a
.
_underlying
,
b
.
_underlying
,
weight
.
_underlying
,
epsilon
)
return
(
y
,
residual_out
)
_infinicore
.
add_rms_norm_
(
out
.
_underlying
,
a
.
_underlying
,
b
.
_underlying
,
weight
.
_underlying
,
epsilon
)
return
out
def
add_rms_norm_
(
y
,
residual_out
,
a
,
b
,
weight
,
epsilon
=
1e-5
):
"""In-place Fused Add and RMS Normalization."""
_infinicore
.
add_rms_norm_
(
y
.
_underlying
,
residual_out
.
_underlying
,
a
.
_underlying
,
b
.
_underlying
,
weight
.
_underlying
,
epsilon
)
src/infinicore/ops/add_rms_norm/add_rms_norm.cc
View file @
2a432b34
...
...
@@ -9,20 +9,21 @@ common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() {
return
dispatcher_
;
};
void
AddRMSNorm
::
execute
(
Tensor
y
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
)
{
INFINICORE_ASSERT_TENSORS_SAME_DEVICE
(
y
,
a
,
b
,
weight
);
void
AddRMSNorm
::
execute
(
Tensor
y
,
Tensor
residual_out
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
)
{
INFINICORE_ASSERT_TENSORS_SAME_DEVICE
(
y
,
residual_out
,
a
,
b
,
weight
);
infinicore
::
context
::
setDevice
(
y
->
device
());
dispatcher
().
lookup
(
y
->
device
().
getType
())(
y
,
a
,
b
,
weight
,
epsilon
);
dispatcher
().
lookup
(
y
->
device
().
getType
())(
y
,
residual_out
,
a
,
b
,
weight
,
epsilon
);
}
Tensor
add_rms_norm
(
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
)
{
std
::
pair
<
Tensor
,
Tensor
>
add_rms_norm
(
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
)
{
auto
y
=
Tensor
::
empty
(
a
->
shape
(),
a
->
dtype
(),
a
->
device
());
add_rms_norm_
(
y
,
a
,
b
,
weight
,
epsilon
);
return
y
;
auto
residual_out
=
Tensor
::
empty
(
a
->
shape
(),
a
->
dtype
(),
a
->
device
());
add_rms_norm_
(
y
,
residual_out
,
a
,
b
,
weight
,
epsilon
);
return
std
::
make_pair
(
y
,
residual_out
);
}
void
add_rms_norm_
(
Tensor
y
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
)
{
AddRMSNorm
::
execute
(
y
,
a
,
b
,
weight
,
epsilon
);
void
add_rms_norm_
(
Tensor
y
,
Tensor
residual_out
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
)
{
AddRMSNorm
::
execute
(
y
,
residual_out
,
a
,
b
,
weight
,
epsilon
);
}
}
// namespace infinicore::op
src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc
View file @
2a432b34
...
...
@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches(
}
});
void
calculate
(
Tensor
y
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
)
{
size_t
seed
=
hash_combine
(
y
,
a
,
b
,
weight
,
epsilon
);
void
calculate
(
Tensor
y
,
Tensor
residual_out
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
)
{
size_t
seed
=
hash_combine
(
y
,
residual_out
,
a
,
b
,
weight
,
epsilon
);
auto
device
=
context
::
getDevice
();
auto
&
cache
=
caches
.
getCache
(
device
);
...
...
@@ -27,7 +27,7 @@ void calculate(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) {
if
(
!
desc_opt
)
{
INFINICORE_CHECK_ERROR
(
infiniopCreateAddRMSNormDescriptor
(
context
::
getInfiniopHandle
(
device
),
&
desc
,
y
->
desc
(),
a
->
desc
(),
b
->
desc
(),
weight
->
desc
(),
epsilon
));
y
->
desc
(),
a
->
desc
(),
b
->
desc
(),
weight
->
desc
(),
epsilon
,
residual_out
->
desc
()
));
cache
.
put
(
seed
,
desc
);
}
else
{
desc
=
*
desc_opt
;
...
...
@@ -39,7 +39,7 @@ void calculate(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) {
INFINICORE_CHECK_ERROR
(
infiniopAddRMSNorm
(
desc
,
workspace
->
data
(),
workspace_size
,
y
->
data
(),
a
->
data
(),
b
->
data
(),
weight
->
data
(),
context
::
getStream
()));
y
->
data
(),
a
->
data
(),
b
->
data
(),
weight
->
data
(),
residual_out
->
data
(),
context
::
getStream
()));
}
static
bool
registered
=
[]()
{
...
...
src/infinicore/pybind11/ops/add_rms_norm.hpp
View file @
2a432b34
...
...
@@ -24,12 +24,14 @@ Args:
epsilon: Small constant for numerical stability, default is 1e-5
Returns:
Normalized tensor: RMSNorm(a + b) * weight
Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b)
The add_result can be used as residual for subsequent layers.
)doc"
);
m
.
def
(
"add_rms_norm_"
,
&
op
::
add_rms_norm_
,
py
::
arg
(
"y"
),
py
::
arg
(
"residual_out"
),
py
::
arg
(
"a"
),
py
::
arg
(
"b"
),
py
::
arg
(
"weight"
),
...
...
@@ -37,7 +39,8 @@ Returns:
R"doc(In-place Fused Add and RMS Normalization.
Args:
y: Output tensor
y: Output tensor for normalized result
residual_out: Output tensor for add result (a + b) before normalization
a: First input tensor
b: Second input tensor
weight: Scale weights
...
...
src/infiniop/ops/add_rms_norm/add_rms_norm.h
View file @
2a432b34
...
...
@@ -36,7 +36,8 @@
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t weight_desc, \
float epsilon); \
float epsilon, \
infiniopTensorDescriptor_t residual_out_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
...
...
@@ -44,6 +45,7 @@
const void *a, \
const void *b, \
const void *weight, \
void *residual_out, \
void *stream) const; \
}; \
}
...
...
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc
View file @
2a432b34
...
...
@@ -13,15 +13,16 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
,
infiniopTensorDescriptor_t
weight_desc
,
float
epsilon
)
{
auto
result
=
AddRMSNormInfo
::
create
(
y_desc
,
a_desc
,
b_desc
,
weight_desc
,
epsilon
);
float
epsilon
,
infiniopTensorDescriptor_t
residual_out_desc
)
{
auto
result
=
AddRMSNormInfo
::
create
(
y_desc
,
a_desc
,
b_desc
,
weight_desc
,
epsilon
,
residual_out_desc
);
CHECK_RESULT
(
result
);
*
desc_ptr
=
new
Descriptor
(
nullptr
,
result
.
take
(),
0
,
handle
->
device
,
handle
->
device_id
);
return
INFINI_STATUS_SUCCESS
;
}
template
<
typename
T
>
infiniStatus_t
add_rmsnorm
(
const
AddRMSNormInfo
*
info
,
T
*
y
,
const
T
*
a
,
const
T
*
b
,
const
T
*
w
)
{
infiniStatus_t
add_rmsnorm
(
const
AddRMSNormInfo
*
info
,
T
*
y
,
const
T
*
a
,
const
T
*
b
,
const
T
*
w
,
T
*
residual_out
)
{
const
size_t
batch_size
=
info
->
shape
[
0
];
const
size_t
nhead
=
info
->
ndim
()
>
2
?
info
->
shape
[
1
]
:
1
;
const
size_t
dim
=
info
->
dim
();
...
...
@@ -35,12 +36,16 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
const
T
*
a_ptr
=
a
+
i
*
info
->
a_strides
[
0
]
+
j
*
info
->
a_strides
[
1
];
const
T
*
b_ptr
=
b
+
i
*
info
->
b_strides
[
0
]
+
j
*
info
->
b_strides
[
1
];
T
*
y_ptr
=
y
+
i
*
info
->
y_strides
[
0
]
+
j
*
info
->
y_strides
[
1
];
T
*
residual_out_ptr
=
info
->
has_residual_out
?
(
residual_out
+
i
*
info
->
residual_out_strides
[
0
]
+
j
*
info
->
residual_out_strides
[
1
])
:
nullptr
;
// First, compute add(a, b) and store sum values
// We'll compute RMS norm directly on the sum
// Compute add(a, b) once and store it
T
sum_squared
=
(
T
)
0
;
for
(
size_t
k
=
0
;
k
<
dim
;
k
++
)
{
T
sum_val
=
a_ptr
[
k
]
+
b_ptr
[
k
];
if
(
residual_out_ptr
!=
nullptr
)
{
residual_out_ptr
[
k
]
=
sum_val
;
// Store add result
}
sum_squared
+=
sum_val
*
sum_val
;
}
...
...
@@ -49,10 +54,18 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
T
rms
=
(
T
)
1
/
std
::
sqrt
(
sum_squared
/
(
T
)(
dim
)
+
(
T
)(
info
->
epsilon
));
// Apply normalization: y = (a + b) * w * rms
// Recompute sum to avoid storing temporary array
for
(
size_t
k
=
0
;
k
<
dim
;
k
++
)
{
T
sum_val
=
a_ptr
[
k
]
+
b_ptr
[
k
];
y_ptr
[
k
]
=
sum_val
*
w
[
k
]
*
rms
;
// Reuse the stored sum values if residual_out was computed, otherwise recompute
if
(
residual_out_ptr
!=
nullptr
)
{
// Reuse stored values
for
(
size_t
k
=
0
;
k
<
dim
;
k
++
)
{
y_ptr
[
k
]
=
residual_out_ptr
[
k
]
*
w
[
k
]
*
rms
;
}
}
else
{
// Recompute sum
for
(
size_t
k
=
0
;
k
<
dim
;
k
++
)
{
T
sum_val
=
a_ptr
[
k
]
+
b_ptr
[
k
];
y_ptr
[
k
]
=
sum_val
*
w
[
k
]
*
rms
;
}
}
}
...
...
@@ -60,7 +73,7 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
}
template
<
typename
T
,
typename
Tw
>
infiniStatus_t
add_rmsnormHalfPrecision
(
const
AddRMSNormInfo
*
info
,
T
*
y
,
const
T
*
a
,
const
T
*
b
,
const
Tw
*
w
)
{
infiniStatus_t
add_rmsnormHalfPrecision
(
const
AddRMSNormInfo
*
info
,
T
*
y
,
const
T
*
a
,
const
T
*
b
,
const
Tw
*
w
,
T
*
residual_out
)
{
static_assert
(
std
::
is_same
<
T
,
fp16_t
>::
value
||
std
::
is_same
<
T
,
bf16_t
>::
value
,
"T must be fp16_t or bf16_t"
);
...
...
@@ -77,11 +90,16 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
const
T
*
a_ptr
=
a
+
i
*
info
->
a_strides
[
0
]
+
j
*
info
->
a_strides
[
1
];
const
T
*
b_ptr
=
b
+
i
*
info
->
b_strides
[
0
]
+
j
*
info
->
b_strides
[
1
];
T
*
y_ptr
=
y
+
i
*
info
->
y_strides
[
0
]
+
j
*
info
->
y_strides
[
1
];
T
*
residual_out_ptr
=
info
->
has_residual_out
?
(
residual_out
+
i
*
info
->
residual_out_strides
[
0
]
+
j
*
info
->
residual_out_strides
[
1
])
:
nullptr
;
// Compute sum of squares for RMS normalization
// Compute sum of squares for RMS normalization
and store add result
float
sum_squared
=
0.0
f
;
for
(
size_t
k
=
0
;
k
<
dim
;
k
++
)
{
float
sum_val
=
utils
::
cast
<
float
>
(
a_ptr
[
k
])
+
utils
::
cast
<
float
>
(
b_ptr
[
k
]);
if
(
residual_out_ptr
!=
nullptr
)
{
residual_out_ptr
[
k
]
=
utils
::
cast
<
T
>
(
sum_val
);
// Store add result
}
sum_squared
+=
sum_val
*
sum_val
;
}
...
...
@@ -89,17 +107,35 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
float
rms
=
1.
f
/
std
::
sqrt
(
sum_squared
/
(
float
)(
dim
)
+
info
->
epsilon
);
// Apply normalization: y = (a + b) * w * rms
for
(
size_t
k
=
0
;
k
<
dim
;
k
++
)
{
float
sum_val
=
utils
::
cast
<
float
>
(
a_ptr
[
k
])
+
utils
::
cast
<
float
>
(
b_ptr
[
k
]);
float
val
;
if
constexpr
(
std
::
is_same
<
Tw
,
float
>::
value
)
{
val
=
sum_val
*
w
[
k
]
*
rms
;
}
else
if
constexpr
(
std
::
is_same
<
Tw
,
T
>::
value
||
std
::
is_same_v
<
Tw
,
fp16_t
>
||
std
::
is_same_v
<
Tw
,
bf16_t
>
)
{
val
=
sum_val
*
utils
::
cast
<
float
>
(
w
[
k
])
*
rms
;
}
else
{
std
::
abort
();
// Reuse stored values if residual_out was computed, otherwise recompute
if
(
residual_out_ptr
!=
nullptr
)
{
// Reuse stored values
for
(
size_t
k
=
0
;
k
<
dim
;
k
++
)
{
float
sum_val
=
utils
::
cast
<
float
>
(
residual_out_ptr
[
k
]);
float
val
;
if
constexpr
(
std
::
is_same
<
Tw
,
float
>::
value
)
{
val
=
sum_val
*
w
[
k
]
*
rms
;
}
else
if
constexpr
(
std
::
is_same
<
Tw
,
T
>::
value
||
std
::
is_same_v
<
Tw
,
fp16_t
>
||
std
::
is_same_v
<
Tw
,
bf16_t
>
)
{
val
=
sum_val
*
utils
::
cast
<
float
>
(
w
[
k
])
*
rms
;
}
else
{
std
::
abort
();
}
y_ptr
[
k
]
=
utils
::
cast
<
T
>
(
val
);
}
}
else
{
// Recompute sum
for
(
size_t
k
=
0
;
k
<
dim
;
k
++
)
{
float
sum_val
=
utils
::
cast
<
float
>
(
a_ptr
[
k
])
+
utils
::
cast
<
float
>
(
b_ptr
[
k
]);
float
val
;
if
constexpr
(
std
::
is_same
<
Tw
,
float
>::
value
)
{
val
=
sum_val
*
w
[
k
]
*
rms
;
}
else
if
constexpr
(
std
::
is_same
<
Tw
,
T
>::
value
||
std
::
is_same_v
<
Tw
,
fp16_t
>
||
std
::
is_same_v
<
Tw
,
bf16_t
>
)
{
val
=
sum_val
*
utils
::
cast
<
float
>
(
w
[
k
])
*
rms
;
}
else
{
std
::
abort
();
}
y_ptr
[
k
]
=
utils
::
cast
<
T
>
(
val
);
}
y_ptr
[
k
]
=
utils
::
cast
<
T
>
(
val
);
}
}
...
...
@@ -109,31 +145,31 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
void
*
y
,
const
void
*
a
,
const
void
*
b
,
const
void
*
weight
,
void
*
stream
)
const
{
void
*
residual_out
,
void
*
stream
)
const
{
if
(
_info
.
atype
==
INFINI_DTYPE_F16
)
{
if
(
_info
.
wtype
==
INFINI_DTYPE_F16
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
fp16_t
*
)
weight
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
fp16_t
*
)
weight
,
(
fp16_t
*
)
residual_out
));
}
else
if
(
_info
.
wtype
==
INFINI_DTYPE_F32
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
float
*
)
weight
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
float
*
)
weight
,
(
fp16_t
*
)
residual_out
));
}
else
if
(
_info
.
wtype
==
INFINI_DTYPE_BF16
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
bf16_t
*
)
weight
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
a
,
(
const
fp16_t
*
)
b
,
(
const
bf16_t
*
)
weight
,
(
fp16_t
*
)
residual_out
));
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_BF16
)
{
if
(
_info
.
wtype
==
INFINI_DTYPE_BF16
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
bf16_t
*
)
weight
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
bf16_t
*
)
weight
,
(
bf16_t
*
)
residual_out
));
}
else
if
(
_info
.
wtype
==
INFINI_DTYPE_F32
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
float
*
)
weight
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
float
*
)
weight
,
(
bf16_t
*
)
residual_out
));
}
else
if
(
_info
.
wtype
==
INFINI_DTYPE_F16
)
{
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
fp16_t
*
)
weight
));
CHECK_STATUS
(
add_rmsnormHalfPrecision
(
&
_info
,
(
bf16_t
*
)
y
,
(
const
bf16_t
*
)
a
,
(
const
bf16_t
*
)
b
,
(
const
fp16_t
*
)
weight
,
(
bf16_t
*
)
residual_out
));
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_F32
)
{
CHECK_STATUS
(
add_rmsnorm
(
&
_info
,
(
float
*
)
y
,
(
const
float
*
)
a
,
(
const
float
*
)
b
,
(
const
float
*
)
weight
));
CHECK_STATUS
(
add_rmsnorm
(
&
_info
,
(
float
*
)
y
,
(
const
float
*
)
a
,
(
const
float
*
)
b
,
(
const
float
*
)
weight
,
(
float
*
)
residual_out
));
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_F64
)
{
CHECK_STATUS
(
add_rmsnorm
(
&
_info
,
(
double
*
)
y
,
(
const
double
*
)
a
,
(
const
double
*
)
b
,
(
const
double
*
)
weight
));
CHECK_STATUS
(
add_rmsnorm
(
&
_info
,
(
double
*
)
y
,
(
const
double
*
)
a
,
(
const
double
*
)
b
,
(
const
double
*
)
weight
,
(
double
*
)
residual_out
));
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
...
...
src/infiniop/ops/add_rms_norm/info.h
View file @
2a432b34
...
...
@@ -18,6 +18,8 @@ public:
std
::
vector
<
ptrdiff_t
>
y_strides
;
std
::
vector
<
ptrdiff_t
>
a_strides
;
std
::
vector
<
ptrdiff_t
>
b_strides
;
std
::
vector
<
ptrdiff_t
>
residual_out_strides
;
bool
has_residual_out
;
size_t
ndim
()
const
{
return
shape
.
size
();
}
size_t
dim
()
const
{
return
shape
[
ndim
()
-
1
];
}
...
...
@@ -27,7 +29,8 @@ public:
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
,
infiniopTensorDescriptor_t
weight_desc
,
float
epsilon
)
{
float
epsilon
,
infiniopTensorDescriptor_t
residual_out_desc
)
{
auto
atype
=
y_desc
->
dtype
();
auto
wtype
=
weight_desc
->
dtype
();
...
...
@@ -95,6 +98,27 @@ public:
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
}
// Check residual_out_desc if provided
bool
has_residual_out
=
(
residual_out_desc
!=
nullptr
);
if
(
has_residual_out
)
{
const
size_t
residual_out_ndim
=
residual_out_desc
->
ndim
();
if
(
residual_out_ndim
!=
y_ndim
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
if
(
residual_out_desc
->
dtype
()
!=
atype
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
// Check shape matches
for
(
size_t
i
=
0
;
i
<
y_ndim
;
i
++
)
{
if
(
residual_out_desc
->
dim
(
i
)
!=
y_desc
->
dim
(
i
))
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
}
if
(
residual_out_desc
->
stride
(
residual_out_ndim
-
1
)
!=
1
)
{
return
INFINI_STATUS_BAD_TENSOR_STRIDES
;
}
}
AddRMSNormInfo
info
;
info
.
wtype
=
wtype
;
info
.
atype
=
atype
;
...
...
@@ -103,6 +127,10 @@ public:
info
.
y_strides
=
y_desc
->
strides
();
info
.
a_strides
=
a_desc
->
strides
();
info
.
b_strides
=
b_desc
->
strides
();
info
.
has_residual_out
=
has_residual_out
;
if
(
has_residual_out
)
{
info
.
residual_out_strides
=
residual_out_desc
->
strides
();
}
return
utils
::
Result
<
AddRMSNormInfo
>
(
info
);
}
};
...
...
src/infiniop/ops/add_rms_norm/operator.cc
View file @
2a432b34
...
...
@@ -37,7 +37,8 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopTensorDescriptor_t
a_desc
,
infiniopTensorDescriptor_t
b_desc
,
infiniopTensorDescriptor_t
weight_desc
,
float
epsilon
)
{
float
epsilon
,
infiniopTensorDescriptor_t
residual_out_desc
)
{
#define CREATE(CASE, NAMESPACE) \
case CASE: \
...
...
@@ -48,7 +49,8 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
a_desc, \
b_desc, \
weight_desc, \
epsilon)
epsilon, \
residual_out_desc)
switch
(
handle
->
device
)
{
#ifdef ENABLE_CPU_API
...
...
@@ -118,12 +120,13 @@ __C infiniStatus_t infiniopAddRMSNorm(
const
void
*
a
,
const
void
*
b
,
const
void
*
weight
,
void
*
residual_out
,
void
*
stream
)
{
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, y, a, b, weight, stream)
->calculate(workspace, workspace_size, y, a, b, weight,
residual_out,
stream)
switch
(
desc
->
device_type
)
{
...
...
test/infinicore/ops/add_rms_norm.py
View file @
2a432b34
...
...
@@ -86,63 +86,35 @@ def parse_test_cases():
)
# Weight is always contiguous
y_spec
=
TensorSpec
.
from_tensor
(
y_shape
,
y_strides
,
input_dtype
)
# Test Case 1: Out-of-place (return value)
# Test Case 1: Out-of-place (return value) - returns (normalized_result, add_result)
residual_out_spec
=
TensorSpec
.
from_tensor
(
a_shape
,
a_strides
,
input_dtype
)
test_cases
.
append
(
TestCase
(
inputs
=
[
a_spec
,
b_spec
,
w_spec
],
kwargs
=
{
"epsilon"
:
_EPSILON
},
output_spec
=
None
,
output_spec
s
=
[
y_spec
,
residual_out_spec
],
# Two outputs
comparison_target
=
None
,
tolerance
=
tolerance
,
output_count
=
2
,
# Two outputs: normalized_result and add_result
description
=
f
"AddRMSNorm - OUT_OF_PLACE"
,
)
)
# Test Case 2: In-place with explicit output tensor (add_rms_norm
(a, b, w, out=y
))
# Test Case 2: In-place with explicit output tensor
s
(add_rms_norm
_(y, residual_out, a, b, w
))
if
y_supports_inplace
:
residual_out_spec
=
TensorSpec
.
from_tensor
(
a_shape
,
a_strides
,
input_dtype
)
test_cases
.
append
(
TestCase
(
inputs
=
[
a_spec
,
b_spec
,
w_spec
],
kwargs
=
{
"epsilon"
:
_EPSILON
},
output_spec
=
y_spec
,
# Specify the output tensor spec
kwargs
=
{
"epsilon"
:
_EPSILON
,
"out"
:
(
y_spec
,
residual_out_spec
)
},
output_spec
s
=
[
y_spec
,
residual_out_spec
],
# Two outputs
comparison_target
=
"out"
,
tolerance
=
tolerance
,
output_count
=
2
,
description
=
f
"AddRMSNorm - INPLACE(out)"
,
)
)
# Test Case 3: In-place on first input (add_rms_norm(a, b, w, out=a))
if
a_supports_inplace
:
test_cases
.
append
(
TestCase
(
inputs
=
[
a_spec
,
b_spec
,
w_spec
],
kwargs
=
{
"out"
:
0
,
"epsilon"
:
_EPSILON
,
},
# Use index 0 for first input
output_spec
=
None
,
comparison_target
=
0
,
# Compare first input
tolerance
=
tolerance
,
description
=
f
"AddRMSNorm - INPLACE(a)"
,
)
)
# Test Case 4: In-place on second input (add_rms_norm(a, b, w, out=b))
if
b_supports_inplace
:
test_cases
.
append
(
TestCase
(
inputs
=
[
a_spec
,
b_spec
,
w_spec
],
kwargs
=
{
"out"
:
1
,
"epsilon"
:
_EPSILON
,
},
# Use index 1 for second input
output_spec
=
None
,
comparison_target
=
1
,
# Compare second input
tolerance
=
tolerance
,
description
=
f
"AddRMSNorm - INPLACE(b)"
,
)
)
return
test_cases
...
...
@@ -156,7 +128,7 @@ class OpTest(BaseOperatorTest):
return
parse_test_cases
()
def
torch_operator
(
self
,
a
,
b
,
weight
,
epsilon
=
_EPSILON
,
out
=
None
,
**
kwargs
):
"""PyTorch AddRMSNorm implementation"""
"""PyTorch AddRMSNorm implementation
- returns (normalized_result, add_result)
"""
input_dtype
=
a
.
dtype
# Compute add(a, b)
...
...
@@ -165,18 +137,27 @@ class OpTest(BaseOperatorTest):
# Calculate RMSNorm: (a + b) * weight / sqrt(mean((a+b)^2) + epsilon)
variance
=
sum_tensor
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
result
=
sum_tensor
*
torch
.
rsqrt
(
variance
+
epsilon
)
*
weight_fp32
normalized_
result
=
sum_tensor
*
torch
.
rsqrt
(
variance
+
epsilon
)
*
weight_fp32
# Convert back to original dtype
result
=
result
.
to
(
input_dtype
)
normalized_result
=
normalized_result
.
to
(
input_dtype
)
add_result
=
sum_tensor
.
to
(
input_dtype
)
if
out
is
not
None
:
out
.
copy_
(
result
)
return
out
return
result
# For in-place operations, we need to handle the output tuple
if
isinstance
(
out
,
(
tuple
,
list
))
and
len
(
out
)
==
2
:
out
[
0
].
copy_
(
normalized_result
)
out
[
1
].
copy_
(
add_result
)
return
tuple
(
out
)
else
:
# Single output - just return normalized result for backward compatibility
out
.
copy_
(
normalized_result
)
return
out
return
(
normalized_result
,
add_result
)
def
infinicore_operator
(
self
,
a
,
b
,
weight
,
epsilon
=
_EPSILON
,
out
=
None
,
**
kwargs
):
"""InfiniCore AddRMSNorm implementation"""
"""InfiniCore AddRMSNorm implementation
- returns (normalized_result, add_result)
"""
return
infinicore
.
add_rms_norm
(
a
,
b
,
weight
,
epsilon
,
out
=
out
)
...
...
test/infiniop/add_rms_norm.py
View file @
2a432b34
...
...
@@ -91,6 +91,7 @@ def test(
)
y
=
TestTensor
(
y_shape
,
y_stride
,
dtype
,
device
,
mode
=
"ones"
)
residual_out
=
TestTensor
(
a_shape
,
a_stride
,
dtype
,
device
,
mode
=
"ones"
)
a
=
TestTensor
(
a_shape
,
a_stride
,
dtype
,
device
,
scale
=
0.01
)
b
=
TestTensor
(
b_shape
,
b_stride
,
dtype
,
device
,
scale
=
0.01
)
w
=
TestTensor
(
w_shape
,
None
,
w_dtype
,
device
)
...
...
@@ -112,11 +113,12 @@ def test(
b
.
descriptor
,
w
.
descriptor
,
eps
,
residual_out
.
descriptor
,
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
for
tensor
in
[
a
,
b
,
y
,
w
]:
for
tensor
in
[
a
,
b
,
y
,
w
,
residual_out
]:
tensor
.
destroy_desc
()
workspace_size
=
c_uint64
(
0
)
...
...
@@ -137,6 +139,7 @@ def test(
a
.
data
(),
b
.
data
(),
w
.
data
(),
residual_out
.
data
(),
None
,
)
)
...
...
@@ -144,9 +147,18 @@ def test(
lib_add_rms_norm
()
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
# Verify normalized result (y)
if
DEBUG
:
debug
(
y
.
actual_tensor
(),
y
.
torch_tensor
(),
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
y
.
actual_tensor
(),
y
.
torch_tensor
(),
atol
=
atol
,
rtol
=
rtol
)
# Verify add result (residual_out) - should be a + b
expected_residual
=
a
.
torch_tensor
().
to
(
torch
.
float32
)
+
b
.
torch_tensor
().
to
(
torch
.
float32
)
expected_residual
=
expected_residual
.
to
(
a
.
torch_tensor
().
dtype
)
if
DEBUG
:
debug
(
residual_out
.
actual_tensor
(),
expected_residual
,
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
residual_out
.
actual_tensor
(),
expected_residual
,
atol
=
atol
,
rtol
=
rtol
)
# Profiling workflow
if
PROFILE
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment