jerrrrry / infinicore

Commit 3b5afffe (unverified), authored Jan 07, 2026 by Haojie Wang, committed by GitHub on Jan 07, 2026

Merge pull request #842 from gongchensu/Issue/791

Issue/791: add the add_rms_norm fused operator

Parents: 2d9d5c30, 7712471f

Showing 20 changed files with 1366 additions and 0 deletions (+1366 / -0)
include/infinicore/ops.hpp                                    (+1, -0)
include/infinicore/ops/add_rms_norm.hpp                       (+20, -0)
include/infiniop.h                                            (+1, -0)
include/infiniop/ops/add_rms_norm.h                           (+32, -0)
python/infinicore/__init__.py                                 (+3, -0)
python/infinicore/ops/add_rms_norm.py                         (+47, -0)
src/infinicore/ops/add_rms_norm/add_rms_norm.cc               (+29, -0)
src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc      (+50, -0)
src/infinicore/pybind11/ops.hpp                               (+2, -0)
src/infinicore/pybind11/ops/add_rms_norm.hpp                  (+51, -0)
src/infiniop/ops/add_rms_norm/add_rms_norm.h                  (+53, -0)
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc         (+147, -0)
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.h          (+7, -0)
src/infiniop/ops/add_rms_norm/cuda/kernel.cuh                 (+63, -0)
src/infiniop/ops/add_rms_norm/info.h                          (+132, -0)
src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu   (+175, -0)
src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cuh  (+8, -0)
src/infiniop/ops/add_rms_norm/operator.cc                     (+189, -0)
test/infinicore/ops/add_rms_norm.py                           (+171, -0)
test/infiniop/add_rms_norm.py                                 (+185, -0)
include/infinicore/ops.hpp

 #pragma once

 #include "ops/add.hpp"
+#include "ops/add_rms_norm.hpp"
 #include "ops/attention.hpp"
 #include "ops/causal_softmax.hpp"
 #include "ops/matmul.hpp"
 ...
include/infinicore/ops/add_rms_norm.hpp (new file, mode 100644)

#pragma once

#include "../device.hpp"
#include "common/op.hpp"

#include <utility>

namespace infinicore::op {

class AddRMSNorm {
public:
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float);
    static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
    static common::OpDispatcher<schema> &dispatcher();
};

// Fused Add and RMS Normalization
// Returns: (normalized_result, add_result)
// The add_result can be used as residual for subsequent layers
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);

void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);

} // namespace infinicore::op
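For reference, the fused semantics documented above can be sketched in a few lines of PyTorch; this mirrors the reference implementation used by the tests at the end of this diff (the function name here is illustrative, not part of the library):

import torch

def add_rms_norm_reference(a, b, weight, eps=1e-5):
    # Residual add plus RMSNorm, accumulated in fp32 for stability.
    s = a.to(torch.float32) + b.to(torch.float32)            # add_result (residual)
    rms = torch.rsqrt(s.pow(2).mean(-1, keepdim=True) + eps)
    y = (s * rms * weight.to(torch.float32)).to(a.dtype)     # normalized_result
    return y, s.to(a.dtype)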
include/infiniop.h

@@ -3,6 +3,7 @@
 #include "infiniop/handle.h"
 #include "infiniop/ops/add.h"
+#include "infiniop/ops/add_rms_norm.h"
 #include "infiniop/ops/attention.h"
 #include "infiniop/ops/causal_softmax.h"
 #include "infiniop/ops/clip.h"
 ...
include/infiniop/ops/add_rms_norm.h (new file, mode 100644)

#ifndef __INFINIOP_ADD_RMS_NORM_API_H__
#define __INFINIOP_ADD_RMS_NORM_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopAddRMSNormDescriptor_t;

__C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
    infiniopHandle_t handle,
    infiniopAddRMSNormDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon,
    infiniopTensorDescriptor_t residual_out_desc);

__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(
    infiniopAddRMSNormDescriptor_t desc,
    size_t *size);

__C __export infiniStatus_t infiniopAddRMSNorm(
    infiniopAddRMSNormDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *a,
    const void *b,
    const void *weight,
    void *residual_out,
    void *stream);

__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(
    infiniopAddRMSNormDescriptor_t desc);

#endif
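The C API follows the usual infiniop descriptor lifecycle: create a descriptor from tensor descriptors, query the workspace size, execute, then destroy. A condensed sketch of that call order in Python, using the same ctypes bindings as the test script at the end of this diff (handle and the tensor/workspace objects are assumed to come from that test harness):

import ctypes
from ctypes import c_uint64
from libinfiniop import LIBINFINIOP, check_error, infiniopOperatorDescriptor_t

desc = infiniopOperatorDescriptor_t()
check_error(LIBINFINIOP.infiniopCreateAddRMSNormDescriptor(
    handle, ctypes.byref(desc),
    y.descriptor, a.descriptor, b.descriptor, w.descriptor,
    1e-5, residual_out.descriptor))

size = c_uint64(0)
check_error(LIBINFINIOP.infiniopGetAddRMSNormWorkspaceSize(desc, ctypes.byref(size)))

# ...allocate a workspace of size.value bytes, then execute on the default stream:
check_error(LIBINFINIOP.infiniopAddRMSNorm(
    desc, workspace.data(), size.value,
    y.data(), a.data(), b.data(), w.data(), residual_out.data(), None))

check_error(LIBINFINIOP.infiniopDestroyAddRMSNormDescriptor(desc))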
python/infinicore/__init__.py

@@ -40,6 +40,7 @@ from infinicore.dtype import (
     uint8,
 )
 from infinicore.ops.add import add
+from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_
 from infinicore.ops.attention import attention
 from infinicore.ops.matmul import matmul
 from infinicore.ops.mul import mul
 ...
@@ -105,6 +106,8 @@ __all__ = [
     "uint8",
     # Operations.
     "add",
+    "add_rms_norm",
+    "add_rms_norm_",
     "attention",
     "matmul",
     "mul",
 ...
python/infinicore/ops/add_rms_norm.py (new file, mode 100644)

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
    """
    Fused Add and RMS Normalization.

    Args:
        a: First input tensor
        b: Second input tensor
        weight: Scale weights
        epsilon: Small constant for numerical stability, default is 1e-5
        out: Optional output tuple (y, residual_out) for in-place operation

    Returns:
        Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b)
        The add_result can be used as residual for subsequent layers.
    """
    if out is None:
        result = _infinicore.add_rms_norm(a._underlying, b._underlying, weight._underlying, epsilon)
        return (Tensor(result[0]), Tensor(result[1]))

    y, residual_out = out
    _infinicore.add_rms_norm_(
        y._underlying,
        residual_out._underlying,
        a._underlying,
        b._underlying,
        weight._underlying,
        epsilon,
    )
    return (y, residual_out)


def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5):
    """In-place Fused Add and RMS Normalization."""
    _infinicore.add_rms_norm_(
        y._underlying,
        residual_out._underlying,
        a._underlying,
        b._underlying,
        weight._underlying,
        epsilon,
    )
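A minimal usage sketch of this Python entry point; tensor construction is elided, and a, b, and weight are assumed to be existing infinicore tensors of matching shape, device, and dtype:

import infinicore

# Out-of-place: allocates and returns both results.
y, residual = infinicore.add_rms_norm(a, b, weight, epsilon=1e-5)

# In-place: write both results into preallocated output tensors instead.
infinicore.add_rms_norm_(y, residual, a, b, weight, epsilon=1e-5)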
src/infinicore/ops/add_rms_norm/add_rms_norm.cc (new file, mode 100644)

#include "infinicore/ops/add_rms_norm.hpp"
#include "../../utils.hpp"

namespace infinicore::op {

common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() {
    static common::OpDispatcher<AddRMSNorm::schema> dispatcher_;
    return dispatcher_;
};

void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight);
    infinicore::context::setDevice(y->device());
    dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon);
}

std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) {
    auto y = Tensor::empty(a->shape(), a->dtype(), a->device());
    auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device());
    add_rms_norm_(y, residual_out, a, b, weight, epsilon);
    return std::make_pair(y, residual_out);
}

void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
    AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon);
}

} // namespace infinicore::op
src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc (new file, mode 100644)

#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/add_rms_norm.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>

namespace infinicore::op::add_rms_norm_impl::infiniop {

thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches(
    100, // capacity
    [](infiniopAddRMSNormDescriptor_t &desc) {
        if (desc != nullptr) {
            INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc));
            desc = nullptr;
        }
    });

void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
    size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon);

    auto device = context::getDevice();
    auto &cache = caches.getCache(device);

    auto desc_opt = cache.get(seed);
    infiniopAddRMSNormDescriptor_t desc = nullptr;
    if (!desc_opt) {
        INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor(
            context::getInfiniopHandle(device), &desc,
            y->desc(), a->desc(), b->desc(), weight->desc(), epsilon,
            residual_out->desc()));
        cache.put(seed, desc);
    } else {
        desc = *desc_opt;
    }

    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size));

    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

    INFINICORE_CHECK_ERROR(infiniopAddRMSNorm(
        desc, workspace->data(), workspace_size,
        y->data(), a->data(), b->data(), weight->data(), residual_out->data(),
        context::getStream()));
}

static bool registered = []() {
    AddRMSNorm::dispatcher().registerAll(&calculate, false);
    return true;
}();

} // namespace infinicore::op::add_rms_norm_impl::infiniop
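The binding layer above caches infiniop descriptors keyed by a hash of the call's tensor metadata, so repeated calls with identical layouts skip descriptor creation entirely. The shape of that strategy, sketched in Python (a simplified stand-in; meta and create_descriptor are hypothetical placeholders for the C++ machinery, and the real cache is thread-local with a per-device capacity of 100):

_desc_cache = {}  # seed -> descriptor

def descriptor_for(y, a, b, w, eps):
    # seed hashes shapes, strides, and dtypes plus epsilon, so calls with
    # identical layouts reuse one descriptor; new layouts build and cache one.
    seed = hash((y.meta, a.meta, b.meta, w.meta, eps))
    desc = _desc_cache.get(seed)
    if desc is None:
        desc = create_descriptor(y, a, b, w, eps)
        _desc_cache[seed] = desc
    return desc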
src/infinicore/pybind11/ops.hpp

@@ -3,6 +3,7 @@
 #include <pybind11/pybind11.h>
 #include "ops/add.hpp"
+#include "ops/add_rms_norm.hpp"
 #include "ops/attention.hpp"
 #include "ops/causal_softmax.hpp"
 #include "ops/embedding.hpp"
 ...
@@ -24,6 +25,7 @@ namespace infinicore::ops {
 inline void bind(py::module &m) {
     bind_add(m);
+    bind_add_rms_norm(m);
     bind_attention(m);
     bind_causal_softmax(m);
     bind_random_sample(m);
 ...
src/infinicore/pybind11/ops/add_rms_norm.hpp (new file, mode 100644)

#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/add_rms_norm.hpp"

namespace py = pybind11;

namespace infinicore::ops {

inline void bind_add_rms_norm(py::module &m) {
    m.def("add_rms_norm",
          &op::add_rms_norm,
          py::arg("a"),
          py::arg("b"),
          py::arg("weight"),
          py::arg("epsilon") = 1e-5f,
          R"doc(Fused Add and RMS Normalization.

Args:
    a: First input tensor
    b: Second input tensor
    weight: Scale weights
    epsilon: Small constant for numerical stability, default is 1e-5

Returns:
    Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b)
    The add_result can be used as residual for subsequent layers.
)doc");

    m.def("add_rms_norm_",
          &op::add_rms_norm_,
          py::arg("y"),
          py::arg("residual_out"),
          py::arg("a"),
          py::arg("b"),
          py::arg("weight"),
          py::arg("epsilon") = 1e-5f,
          R"doc(In-place Fused Add and RMS Normalization.

Args:
    y: Output tensor for normalized result
    residual_out: Output tensor for add result (a + b) before normalization
    a: First input tensor
    b: Second input tensor
    weight: Scale weights
    epsilon: Small constant for numerical stability, default is 1e-5
)doc");
}

} // namespace infinicore::ops
src/infiniop/ops/add_rms_norm/add_rms_norm.h (new file, mode 100644)

#ifndef ADD_RMS_NORM_H
#define ADD_RMS_NORM_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::add_rms_norm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
AddRMSNormInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
AddRMSNormInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t weight_desc, \
float epsilon, \
infiniopTensorDescriptor_t residual_out_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
const void *a, \
const void *b, \
const void *weight, \
void *residual_out, \
void *stream) const; \
}; \
}
#endif // ADD_RMS_NORM_H
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc (new file, mode 100644)

#include "add_rms_norm_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../reduce/cpu/reduce.h"

namespace op::add_rms_norm::cpu {

Descriptor::~Descriptor() {}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon,
    infiniopTensorDescriptor_t residual_out_desc) {
    auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
    CHECK_RESULT(result);

    *desc_ptr = new Descriptor(
        nullptr,
        result.take(),
        0,
        handle->device,
        handle->device_id);

    return INFINI_STATUS_SUCCESS;
}

template <typename T>
infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const T *w, T *residual_out) {
    const size_t batch_size = info->shape[0];
    const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
    const size_t dim = info->dim();
    const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);

#pragma omp parallel for
    for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) {
        const size_t i = block_idx / nhead; // batch index
        const size_t j = block_idx % nhead; // head index

        const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
        const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
        T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
        T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1];

        // Compute add(a, b) once and store it
        T sum_squared = (T)0;
        for (size_t k = 0; k < dim; k++) {
            T sum_val = a_ptr[k] + b_ptr[k];
            residual_out_ptr[k] = sum_val; // Store add result
            sum_squared += sum_val * sum_val;
        }

        // Compute RMS: 1 / (sqrt(mean(sum^2) + eps))
        // Note: mean = sum_squared / dim
        T rms = (T)1 / std::sqrt(sum_squared / (T)(dim) + (T)(info->epsilon));

        // Apply normalization: y = (a + b) * w * rms
        // Reuse stored values from residual_out
        for (size_t k = 0; k < dim; k++) {
            y_ptr[k] = residual_out_ptr[k] * w[k] * rms;
        }
    }

    return INFINI_STATUS_SUCCESS;
}

template <typename T, typename Tw>
infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const Tw *w, T *residual_out) {
    static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value,
                  "T must be fp16_t or bf16_t");

    const size_t batch_size = info->shape[0];
    const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
    const size_t dim = info->dim();
    const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);

#pragma omp parallel for
    for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) {
        const size_t i = block_idx / nhead; // batch index
        const size_t j = block_idx % nhead; // head index

        const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
        const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
        T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
        T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1];

        // Compute sum of squares for RMS normalization and store add result
        float sum_squared = 0.0f;
        for (size_t k = 0; k < dim; k++) {
            float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
            residual_out_ptr[k] = utils::cast<T>(sum_val); // Store add result
            sum_squared += sum_val * sum_val;
        }

        // Compute RMS: 1 / (sqrt(sum/dim + eps))
        float rms = 1.f / std::sqrt(sum_squared / (float)(dim) + info->epsilon);

        // Apply normalization: y = (a + b) * w * rms
        // Reuse stored values from residual_out
        for (size_t k = 0; k < dim; k++) {
            float sum_val = utils::cast<float>(residual_out_ptr[k]);
            float val;
            if constexpr (std::is_same<Tw, float>::value) {
                val = sum_val * w[k] * rms;
            } else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
                val = sum_val * utils::cast<float>(w[k]) * rms;
            } else {
                std::abort();
            }
            y_ptr[k] = utils::cast<T>(val);
        }
    }

    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y,
    const void *a,
    const void *b,
    const void *weight,
    void *residual_out,
    void *stream) const {
    if (_info.atype == INFINI_DTYPE_F16) {
        if (_info.wtype == INFINI_DTYPE_F16) {
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight, (fp16_t *)residual_out));
        } else if (_info.wtype == INFINI_DTYPE_F32) {
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight, (fp16_t *)residual_out));
        } else if (_info.wtype == INFINI_DTYPE_BF16) {
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight, (fp16_t *)residual_out));
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (_info.atype == INFINI_DTYPE_BF16) {
        if (_info.wtype == INFINI_DTYPE_BF16) {
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight, (bf16_t *)residual_out));
        } else if (_info.wtype == INFINI_DTYPE_F32) {
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight, (bf16_t *)residual_out));
        } else if (_info.wtype == INFINI_DTYPE_F16) {
            CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight, (bf16_t *)residual_out));
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (_info.atype == INFINI_DTYPE_F32) {
        CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (const float *)a, (const float *)b, (const float *)weight, (float *)residual_out));
    } else if (_info.atype == INFINI_DTYPE_F64) {
        CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (const double *)a, (const double *)b, (const double *)weight, (double *)residual_out));
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::add_rms_norm::cpu
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.h (new file, mode 100644)

#ifndef __ADD_RMS_NORM_CPU_H__
#define __ADD_RMS_NORM_CPU_H__

#include "../add_rms_norm.h"

DESCRIPTOR(cpu)

#endif
src/infiniop/ops/add_rms_norm/cuda/kernel.cuh (new file, mode 100644)

#ifndef __ADD_RMS_NORM_CUDA_KERNEL_H__
#define __ADD_RMS_NORM_CUDA_KERNEL_H__

#include <cub/block/block_reduce.cuh>

template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
__device__ void add_rmsnormBlock(
    Tdata *__restrict__ y,
    Tdata *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,
    size_t nhead,
    size_t dim,
    float epsilon) {
    // Each block takes care of one head in one batch
    // Each thread deals with every block_size element in the row
    size_t batch_idx = blockIdx.x / nhead;
    size_t head_idx = blockIdx.x % nhead;

    auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
    auto a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
    auto b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
    auto w_ptr = w;
    Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;

    // Compute add(a, b) and sum of squares in one pass
    Tcompute sum_squared = 0;
    for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
        Tcompute sum_val = Tcompute(a_ptr[i]) + Tcompute(b_ptr[i]);
        residual_out_ptr[i] = Tdata(sum_val); // Store add result
        sum_squared += sum_val * sum_val;
    }

    // Block-reduce sum of squares
    using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    sum_squared = BlockReduce(temp_storage).Sum(sum_squared);

    // Thread_0 computes RMS=1/sqrt(ss/dim+epsilon) and stores in shared memory
    __shared__ Tcompute rms;
    if (threadIdx.x == 0) {
        rms = Tcompute(rsqrtf(sum_squared / Tcompute(dim) + epsilon));
    }
    __syncthreads();

    // Apply normalization: y = (a + b) * w * rms
    // Reuse stored values from residual_out
    for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
        Tcompute sum_val = Tcompute(residual_out_ptr[i]); // Reuse stored value
        y_ptr[i] = Tdata(sum_val * Tcompute(w_ptr[i]) * rms);
    }
}

#endif
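The kernel's one-block-per-row structure can be modeled in plain Python to make the data flow explicit: pass one writes the residual and accumulates the sum of squares, the reduction yields a single rms per row, and pass two rereads the stored residual instead of recomputing a + b. A toy, single-threaded model only; the real kernel stripes elements across BLOCK_SIZE threads and uses cub's block reduction:

def add_rmsnorm_row(a_row, b_row, w_row, eps):
    # Pass 1: fused add (stored for reuse) plus sum-of-squares accumulation.
    residual = [x + y for x, y in zip(a_row, b_row)]
    ss = sum(v * v for v in residual)              # block-wide reduction in the kernel
    rms = 1.0 / (ss / len(residual) + eps) ** 0.5  # thread 0; broadcast via shared memory
    # Pass 2: normalize, rereading the stored residual instead of recomputing a + b.
    y_row = [v * w * rms for v, w in zip(residual, w_row)]
    return y_row, residual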
src/infiniop/ops/add_rms_norm/info.h (new file, mode 100644)

#ifndef __ADD_RMS_NORM_INFO_H__
#define __ADD_RMS_NORM_INFO_H__

#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>

namespace op::add_rms_norm {

class AddRMSNormInfo {
    AddRMSNormInfo() = default;

public:
    infiniDtype_t wtype;
    infiniDtype_t atype;
    float epsilon;
    std::vector<size_t> shape;
    std::vector<ptrdiff_t> y_strides;
    std::vector<ptrdiff_t> a_strides;
    std::vector<ptrdiff_t> b_strides;
    std::vector<ptrdiff_t> residual_out_strides;
    bool has_residual_out;

    size_t ndim() const { return shape.size(); }
    size_t dim() const { return shape[ndim() - 1]; }

    static utils::Result<AddRMSNormInfo> create(
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t a_desc,
        infiniopTensorDescriptor_t b_desc,
        infiniopTensorDescriptor_t weight_desc,
        float epsilon,
        infiniopTensorDescriptor_t residual_out_desc) {
        auto atype = y_desc->dtype();
        auto wtype = weight_desc->dtype();

        // Check that all input tensors have the same dtype
        if (a_desc->dtype() != atype || b_desc->dtype() != atype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        if (atype == INFINI_DTYPE_F16 || atype == INFINI_DTYPE_BF16) {
            // For half-precision types (FP16/BF16), weights can be the same half-precision type or FP32
            if (wtype != atype && wtype != INFINI_DTYPE_F32 && wtype != INFINI_DTYPE_BF16 && wtype != INFINI_DTYPE_F16) {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
        } else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
            // For FP32/FP64, activations and weights must be of the same type
            if (atype != wtype) {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        const size_t y_ndim = y_desc->ndim();
        const size_t a_ndim = a_desc->ndim();
        const size_t b_ndim = b_desc->ndim();
        const size_t w_ndim = weight_desc->ndim();
        if (y_ndim != a_ndim || y_ndim != b_ndim || w_ndim != 1) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        size_t batch = 1;
        size_t nhead = 1;
        size_t dim = 0;
        if (y_ndim == 2) {
            batch = y_desc->dim(0);
            dim = y_desc->dim(1);
            if (a_desc->dim(0) != batch || a_desc->dim(1) != dim
                || b_desc->dim(0) != batch || b_desc->dim(1) != dim
                || weight_desc->dim(0) != dim) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
        } else if (y_ndim == 3) {
            batch = y_desc->dim(0);
            nhead = y_desc->dim(1);
            dim = y_desc->dim(2);
            if (a_desc->dim(0) != batch || a_desc->dim(1) != nhead || a_desc->dim(2) != dim
                || b_desc->dim(0) != batch || b_desc->dim(1) != nhead || b_desc->dim(2) != dim
                || weight_desc->dim(0) != dim) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
        } else {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // Check contiguity of the last dimension
        if (y_desc->stride(y_ndim - 1) != 1
            || a_desc->stride(a_ndim - 1) != 1
            || b_desc->stride(b_ndim - 1) != 1
            || weight_desc->stride(w_ndim - 1) != 1) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }

        // residual_out_desc is required (always needed for fused operator)
        if (residual_out_desc == nullptr) {
            return INFINI_STATUS_BAD_PARAM;
        }
        const size_t residual_out_ndim = residual_out_desc->ndim();
        if (residual_out_ndim != y_ndim) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (residual_out_desc->dtype() != atype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        // Check shape matches
        for (size_t i = 0; i < y_ndim; i++) {
            if (residual_out_desc->dim(i) != y_desc->dim(i)) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
        }
        if (residual_out_desc->stride(residual_out_ndim - 1) != 1) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }

        AddRMSNormInfo info;
        info.wtype = wtype;
        info.atype = atype;
        info.epsilon = epsilon;
        info.shape = y_desc->shape();
        info.y_strides = y_desc->strides();
        info.a_strides = a_desc->strides();
        info.b_strides = b_desc->strides();
        info.has_residual_out = true; // Always true now
        info.residual_out_strides = residual_out_desc->strides();
        return utils::Result<AddRMSNormInfo>(info);
    }
};

} // namespace op::add_rms_norm

#endif // __ADD_RMS_NORM_INFO_H__
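The validation above reduces to a small set of rules: activations and both outputs share one dtype (weight may differ per the half-precision rules); the weight is 1-D with length equal to the last dimension; tensors are 2-D (batch, dim) or 3-D (batch, nhead, dim); and every last dimension is contiguous. A compact restatement as a standalone Python check over torch-like tensors (a hypothetical helper that mirrors the logic above rather than calling it):

def check_add_rms_norm_shapes(y, a, b, w, residual_out):
    # Activations and both outputs share a dtype; weight dtype follows the rules above.
    assert a.dtype == b.dtype == y.dtype == residual_out.dtype
    # Weight is 1-D and matches the normalized (last) dimension.
    assert w.ndim == 1 and w.shape[0] == y.shape[-1]
    # Only 2-D (batch, dim) or 3-D (batch, nhead, dim) layouts are accepted.
    assert y.ndim in (2, 3)
    assert a.shape == b.shape == y.shape == residual_out.shape
    # The last dimension of every tensor must be contiguous (stride 1).
    for t in (y, a, b, w, residual_out):
        assert t.stride(-1) == 1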
src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu (new file, mode 100644)

#include "../../../devices/nvidia/nvidia_common.cuh"
#include "add_rms_norm_nvidia.cuh"

#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include <cub/block/block_reduce.cuh>

#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"

template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL add_rmsnormKernel(
    Tdata *__restrict__ y,
    Tdata *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,
    size_t nhead,
    size_t dim,
    float epsilon) {
    add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
        y, residual_out,
        stride_y_batch, stride_y_nhead,
        stride_residual_out_batch, stride_residual_out_nhead,
        a, stride_a_batch, stride_a_nhead,
        b, stride_b_batch, stride_b_nhead,
        w, nhead, dim, epsilon);
}

namespace op::add_rms_norm::nvidia {

struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon,
    infiniopTensorDescriptor_t residual_out_desc) {
    auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
    CHECK_RESULT(result);
    auto info = result.take();

    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
        std::move(info),
        0,
        handle->device,
        handle->device_id);

    return INFINI_STATUS_SUCCESS;
}

// launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
    uint32_t batch_size,
    size_t nhead,
    size_t dim,
    void *y,
    infiniDtype_t atype,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    void *residual_out,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const void *a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const void *b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const void *w,
    infiniDtype_t wtype,
    float epsilon,
    cudaStream_t cuda_stream) {

#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
    add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, cuda_stream>>>( \
        reinterpret_cast<Tdata *>(y), \
        reinterpret_cast<Tdata *>(residual_out), \
        stride_y_batch, \
        stride_y_nhead, \
        stride_residual_out_batch, \
        stride_residual_out_nhead, \
        reinterpret_cast<const Tdata *>(a), \
        stride_a_batch, \
        stride_a_nhead, \
        reinterpret_cast<const Tdata *>(b), \
        stride_b_batch, \
        stride_b_nhead, \
        reinterpret_cast<const Tweight *>(w), \
        nhead, \
        dim, \
        epsilon)

    if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(half, half, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(half, __nv_bfloat16, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(half, float, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(__nv_bfloat16, __nv_bfloat16, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(__nv_bfloat16, half, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(__nv_bfloat16, float, float);
    } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(float, float, float);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

#undef LAUNCH_KERNEL

    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y,
    const void *a,
    const void *b,
    const void *weight,
    void *residual_out,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    auto stride_a_batch = _info.a_strides[0];
    auto stride_a_nhead = _info.a_strides[1];
    auto stride_b_batch = _info.b_strides[0];
    auto stride_b_nhead = _info.b_strides[1];
    auto stride_y_batch = _info.y_strides[0];
    auto stride_y_nhead = _info.y_strides[1];
    auto stride_residual_out_batch = _info.residual_out_strides[0];
    auto stride_residual_out_nhead = _info.residual_out_strides[1];
    auto dim = _info.dim();
    uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
    size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);

    // launch kernel with different block sizes
    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, cuda_stream));
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, cuda_stream));
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_4096>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, cuda_stream));
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }

    return INFINI_STATUS_SUCCESS;
}

} // namespace op::add_rms_norm::nvidia
src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cuh (new file, mode 100644)

#ifndef __ADD_RMS_NORM_NVIDIA_CUDA_H__
#define __ADD_RMS_NORM_NVIDIA_CUDA_H__

#include "../add_rms_norm.h"

DESCRIPTOR(nvidia)

#endif
src/infiniop/ops/add_rms_norm/operator.cc (new file, mode 100644)

#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/add_rms_norm.h"
#ifdef ENABLE_CPU_API
#include "cpu/add_rms_norm_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#include "nvidia/add_rms_norm_nvidia.cuh"
#endif
#ifdef ENABLE_ASCEND_API
// TODO: Add Ascend implementation
// #include "ascend/add_rms_norm_aclnn.h"
#endif
#ifdef ENABLE_CAMBRICON_API
// TODO: Add Cambricon implementation
// #include "bang/add_rms_norm_bang.h"
#endif
#ifdef ENABLE_METAX_API
// TODO: Add Metax implementation
// #include "metax/add_rms_norm_metax.cuh"
#endif
#ifdef ENABLE_MOORE_API
// TODO: Add Moore implementation
// #include "moore/add_rms_norm_moore.h"
#endif
#ifdef ENABLE_KUNLUN_API
// TODO: Add Kunlun implementation
// #include "kunlun/add_rms_norm_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
    infiniopHandle_t handle,
    infiniopAddRMSNormDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon,
    infiniopTensorDescriptor_t residual_out_desc) {

#define CREATE(CASE, NAMESPACE) \
    case CASE: \
        return op::add_rms_norm::NAMESPACE::Descriptor::create( \
            handle, \
            reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor **>(desc_ptr), \
            y_desc, \
            a_desc, \
            b_desc, \
            weight_desc, \
            epsilon, \
            residual_out_desc)

    switch (handle->device) {
#ifdef ENABLE_CPU_API
        CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
        // CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}

__C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(
    infiniopAddRMSNormDescriptor_t desc,
    size_t *size) {

#define GET(CASE, NAMESPACE) \
    case CASE: \
        *size = reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        GET(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
        // GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET

    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopAddRMSNorm(
    infiniopAddRMSNormDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *a,
    const void *b,
    const void *weight,
    void *residual_out,
    void *stream) {

#define CALCULATE(CASE, NAMESPACE) \
    case CASE: \
        return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc) \
            ->calculate(workspace, workspace_size, y, a, b, weight, residual_out, stream)

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
        // CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}

__C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(
    infiniopAddRMSNormDescriptor_t desc) {
    if (desc == nullptr) {
        return INFINI_STATUS_SUCCESS;
    }

#define DESTROY(CASE, NAMESPACE) \
    case CASE: \
        delete reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        DESTROY(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        DESTROY(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        DESTROY(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        DESTROY(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
        // DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DESTROY
}
test/infinicore/ops/add_rms_norm.py (new file, mode 100644)

import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import infinicore
from framework import (
    BaseOperatorTest,
    TensorSpec,
    TestCase,
    GenericTestRunner,
    is_broadcast,
)

# ==============================================================================
# Operator-specific configuration
# ==============================================================================

# Test cases format: (y_shape, a_shape, b_shape, w_shape, y_strides, a_strides, b_strides)
_TEST_CASES_DATA = [
    # Basic cases
    ((1, 4), (1, 4), (1, 4), (4,), None, None, None),
    ((2, 4), (2, 4), (2, 4), (4,), None, None, None),
    ((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), None, None, None),
    # Strided cases
    ((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1), (12, 8, 1)),
    # Large tensors
    ((16, 2048), (16, 2048), (16, 2048), (2048,), None, None, None),
    ((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
    ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
    ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
    ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)),
    ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)),
]

# Tolerance configuration
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 2e-3, "rtol": 2e-3},
    infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-5, "rtol": 1e-4},
}

# Data types for individual tensors
_INPUT_DTYPES = [infinicore.float16, infinicore.bfloat16]
_WEIGHT_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]

# EPSILON constant for AddRMSNorm
_EPSILON = 1e-5


def parse_test_cases():
    """
    Parse AddRMSNorm test case data and return list of TestCase objects.
    Format: (y_shape, a_shape, b_shape, w_shape, y_strides, a_strides, b_strides)
    """
    test_cases = []

    for data in _TEST_CASES_DATA:
        y_shape = data[0]  # Output shape
        a_shape = data[1]  # First input shape
        b_shape = data[2]  # Second input shape
        w_shape = data[3]  # Weight shape (1D)
        y_strides = data[4] if len(data) > 4 else None
        a_strides = data[5] if len(data) > 5 else None
        b_strides = data[6] if len(data) > 6 else None

        # Check if tensors support in-place operations
        a_supports_inplace = not is_broadcast(a_strides)
        b_supports_inplace = not is_broadcast(b_strides)
        y_supports_inplace = not is_broadcast(y_strides)

        # Generate test cases for all dtype combinations
        for input_dtype in _INPUT_DTYPES:
            for weight_dtype in _WEIGHT_DTYPES:
                # Use input dtype tolerance for output
                tolerance = _TOLERANCE_MAP.get(input_dtype, {"atol": 1e-5, "rtol": 1e-4})

                # Create typed tensor specs
                a_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
                b_spec = TensorSpec.from_tensor(b_shape, b_strides, input_dtype)
                w_spec = TensorSpec.from_tensor(w_shape, None, weight_dtype)  # Weight is always contiguous
                y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype)

                # Test Case 1: Out-of-place (return value) - returns (normalized_result, add_result)
                residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec, w_spec],
                        kwargs={"epsilon": _EPSILON},
                        output_specs=[y_spec, residual_out_spec],  # Two outputs
                        comparison_target=None,
                        tolerance=tolerance,
                        output_count=2,  # Two outputs: normalized_result and add_result
                        description=f"AddRMSNorm - OUT_OF_PLACE",
                    )
                )

                # Test Case 2: In-place with explicit output tensors (add_rms_norm_(y, residual_out, a, b, w))
                if y_supports_inplace:
                    residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
                    test_cases.append(
                        TestCase(
                            inputs=[a_spec, b_spec, w_spec],
                            kwargs={"epsilon": _EPSILON, "out": (y_spec, residual_out_spec)},
                            output_specs=[y_spec, residual_out_spec],  # Two outputs
                            comparison_target="out",
                            tolerance=tolerance,
                            output_count=2,
                            description=f"AddRMSNorm - INPLACE(out)",
                        )
                    )

    return test_cases


class OpTest(BaseOperatorTest):
    """AddRMSNorm operator test with simplified implementation"""

    def __init__(self):
        super().__init__("AddRMSNorm")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
        """PyTorch AddRMSNorm implementation - returns (normalized_result, add_result)"""
        input_dtype = a.dtype

        # Compute add(a, b)
        sum_tensor = a.to(torch.float32) + b.to(torch.float32)
        weight_fp32 = weight.to(torch.float32)

        # Calculate RMSNorm: (a + b) * weight / sqrt(mean((a+b)^2) + epsilon)
        variance = sum_tensor.pow(2).mean(-1, keepdim=True)
        normalized_result = sum_tensor * torch.rsqrt(variance + epsilon) * weight_fp32

        # Convert back to original dtype
        normalized_result = normalized_result.to(input_dtype)
        add_result = sum_tensor.to(input_dtype)

        if out is not None:
            # For in-place operations, we need to handle the output tuple
            if isinstance(out, (tuple, list)) and len(out) == 2:
                out[0].copy_(normalized_result)
                out[1].copy_(add_result)
                return tuple(out)
            else:
                # Single output - just return normalized result for backward compatibility
                out.copy_(normalized_result)
                return out

        return (normalized_result, add_result)

    def infinicore_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
        """InfiniCore AddRMSNorm implementation - returns (normalized_result, add_result)"""
        return infinicore.add_rms_norm(a, b, weight, epsilon, out=out)


def main():
    """Main entry point"""
    runner = GenericTestRunner(OpTest)
    runner.run_and_exit()


if __name__ == "__main__":
    main()
test/infiniop/add_rms_norm.py (new file, mode 100644)

import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)

# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # y_shape, a_shape, b_shape, w_shape, y_stride, a_stride, b_stride
    ((1, 4), (1, 4), (1, 4), (4,), None, None, None),
    ((2, 4), (2, 4), (2, 4), (4,), None, None, None),
    ((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), None, None, None),
    ((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1), (12, 8, 1)),
    ((16, 2048), (16, 2048), (16, 2048), (2048,), None, None, None),
    ((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
    ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
    ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
    ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)),
    ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)),
    ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
    ((15, 8192), (15, 8192), (15, 8192), (8192,), None, None, None),
]

# w (weight) types
# Note: 'None' means the same as input dtype
_WEIGHT_DTYPES = [None, InfiniDtype.F32, InfiniDtype.F16, InfiniDtype.BF16]
# a, b types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16]

# Form the test cases by appending each element of _WEIGHT_DTYPES to each tuple in _TEST_CASES_
_TEST_CASES = [
    test_case + (w_dtype,)
    for test_case in _TEST_CASES_
    for w_dtype in _WEIGHT_DTYPES
]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 2e-3, "rtol": 2e-3},
    InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
}

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


def add_rms_norm(ans, a, b, w, eps):
    input_dtype = a.dtype
    # Compute add(a, b)
    sum_tensor = a.to(torch.float32) + b.to(torch.float32)
    # Compute RMS normalization
    scale = sum_tensor.pow(2).mean(-1, keepdim=True).add_(eps).rsqrt_()
    ans.set_((sum_tensor.mul_(scale).mul_(w.to(torch.float32))).to(input_dtype))


def test(
    handle,
    device,
    y_shape,
    a_shape,
    b_shape,
    w_shape,
    y_stride,
    a_stride,
    b_stride,
    w_dtype=InfiniDtype.F32,
    dtype=InfiniDtype.F16,
    sync=None,
):
    w_dtype = w_dtype if w_dtype else dtype
    print(
        f"Testing AddRMSNorm on {InfiniDeviceNames[device]} with y_shape: {y_shape} a_shape: {a_shape} b_shape: {b_shape} w_shape: {w_shape}"
        f" y_stride: {y_stride} a_stride: {a_stride} b_stride: {b_stride} w_dtype: {InfiniDtypeNames[w_dtype]} dtype: {InfiniDtypeNames[dtype]}"
    )

    y = TestTensor(y_shape, y_stride, dtype, device, mode="ones")
    residual_out = TestTensor(a_shape, a_stride, dtype, device, mode="ones")
    a = TestTensor(a_shape, a_stride, dtype, device, scale=0.01)
    b = TestTensor(b_shape, b_stride, dtype, device, scale=0.01)
    w = TestTensor(w_shape, None, w_dtype, device)

    eps = 1e-6
    add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps)

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateAddRMSNormDescriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            a.descriptor,
            b.descriptor,
            w.descriptor,
            eps,
            residual_out.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [a, b, y, w, residual_out]:
        tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetAddRMSNormWorkspaceSize(descriptor, ctypes.byref(workspace_size))
    )
    workspace = TestWorkspace(workspace_size.value, y.device)

    def lib_add_rms_norm():
        check_error(
            LIBINFINIOP.infiniopAddRMSNorm(
                descriptor,
                workspace.data(),
                workspace_size.value,
                y.data(),
                a.data(),
                b.data(),
                w.data(),
                residual_out.data(),
                None,
            )
        )

    lib_add_rms_norm()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)

    # Verify normalized result (y)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    # Verify add result (residual_out) - should be a + b
    expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to(torch.float32)
    expected_residual = expected_residual.to(a.torch_tensor().dtype)
    if DEBUG:
        debug(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)
    assert torch.allclose(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_add_rms_norm(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(LIBINFINIOP.infiniopDestroyAddRMSNormDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")