Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
7d60e5b8
Commit
7d60e5b8
authored
Dec 23, 2025
by
zhuyue
Committed by
gongchensu
Dec 24, 2025
Browse files
增加cpu的add rms_norm算子,c++和python接口
parent
12cde8eb
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
1089 additions
and
0 deletions
+1089
-0
include/infinicore/ops.hpp
include/infinicore/ops.hpp
+1
-0
include/infinicore/ops/add_rms_norm.hpp
include/infinicore/ops/add_rms_norm.hpp
+16
-0
include/infiniop.h
include/infiniop.h
+1
-0
include/infiniop/ops/add_rms_norm.h
include/infiniop/ops/add_rms_norm.h
+30
-0
python/infinicore/__init__.py
python/infinicore/__init__.py
+2
-0
python/infinicore/ops/add_rms_norm.py
python/infinicore/ops/add_rms_norm.py
+11
-0
src/infinicore/ops/add_rms_norm/add_rms_norm.cc
src/infinicore/ops/add_rms_norm/add_rms_norm.cc
+28
-0
src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc
src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc
+50
-0
src/infinicore/pybind11/ops.hpp
src/infinicore/pybind11/ops.hpp
+2
-0
src/infinicore/pybind11/ops/add_rms_norm.hpp
src/infinicore/pybind11/ops/add_rms_norm.hpp
+48
-0
src/infiniop/ops/add_rms_norm/add_rms_norm.h
src/infiniop/ops/add_rms_norm/add_rms_norm.h
+51
-0
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc
+143
-0
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.h
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.h
+7
-0
src/infiniop/ops/add_rms_norm/info.h
src/infiniop/ops/add_rms_norm/info.h
+112
-0
src/infiniop/ops/add_rms_norm/operator.cc
src/infiniop/ops/add_rms_norm/operator.cc
+187
-0
test/infinicore/ops/add_rms_norm.py
test/infinicore/ops/add_rms_norm.py
+190
-0
test/infiniop/add_rms_norm.py
test/infiniop/add_rms_norm.py
+173
-0
test/infiniop/libinfiniop/op_register.py
test/infiniop/libinfiniop/op_register.py
+37
-0
No files found.
include/infinicore/ops.hpp
View file @
7d60e5b8
#pragma once
#pragma once
#include "ops/add.hpp"
#include "ops/add.hpp"
#include "ops/add_rms_norm.hpp"
#include "ops/attention.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/matmul.hpp"
#include "ops/matmul.hpp"
...
...
include/infinicore/ops/add_rms_norm.hpp
0 → 100644
View file @
7d60e5b8
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

/// Fused Add + RMS normalization operator: y = RMSNorm(a + b) * weight.
/// Dispatches to a backend implementation registered per device type.
class AddRMSNorm {
public:
    /// Backend entry-point signature: (y, a, b, weight, epsilon).
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, float);

    /// Runs the fused op on y's device, writing the result into y.
    static void execute(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);

    /// Per-device registry of backend implementations.
    static common::OpDispatcher<schema> &dispatcher();
};

/// Allocating form: returns a fresh tensor holding RMSNorm(a + b) * weight.
Tensor add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);

/// Out-parameter form: writes RMSNorm(a + b) * weight into y.
void add_rms_norm_(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);

} // namespace infinicore::op
include/infiniop.h
View file @
7d60e5b8
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include "infiniop/handle.h"
#include "infiniop/handle.h"
#include "infiniop/ops/add.h"
#include "infiniop/ops/add.h"
#include "infiniop/ops/add_rms_norm.h"
#include "infiniop/ops/attention.h"
#include "infiniop/ops/attention.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/clip.h"
#include "infiniop/ops/clip.h"
...
...
include/infiniop/ops/add_rms_norm.h
0 → 100644
View file @
7d60e5b8
#ifndef __INFINIOP_ADD_RMS_NORM_API_H__
#define __INFINIOP_ADD_RMS_NORM_API_H__

#include "../operator_descriptor.h"

// Opaque handle for a fused Add + RMSNorm descriptor.
typedef struct InfiniopDescriptor *infiniopAddRMSNormDescriptor_t;

// Creates a descriptor validating (y, a, b, weight) shapes/dtypes; epsilon is
// the numerical-stability constant added inside the square root.
__C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
    infiniopHandle_t handle,
    infiniopAddRMSNormDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon);

// Queries the scratch-space size the descriptor's backend needs.
__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(
    infiniopAddRMSNormDescriptor_t desc,
    size_t *size);

// Executes y = RMSNorm(a + b) * weight on the given stream.
__C __export infiniStatus_t infiniopAddRMSNorm(
    infiniopAddRMSNormDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *a,
    const void *b,
    const void *weight,
    void *stream);

// Releases a descriptor previously created by the Create call.
__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(
    infiniopAddRMSNormDescriptor_t desc);

#endif
python/infinicore/__init__.py
View file @
7d60e5b8
...
@@ -40,6 +40,7 @@ from infinicore.dtype import (
...
@@ -40,6 +40,7 @@ from infinicore.dtype import (
uint8
,
uint8
,
)
)
from
infinicore.ops.add
import
add
from
infinicore.ops.add
import
add
from
infinicore.ops.add_rms_norm
import
add_rms_norm
from
infinicore.ops.attention
import
attention
from
infinicore.ops.attention
import
attention
from
infinicore.ops.matmul
import
matmul
from
infinicore.ops.matmul
import
matmul
from
infinicore.ops.mul
import
mul
from
infinicore.ops.mul
import
mul
...
@@ -102,6 +103,7 @@ __all__ = [
...
@@ -102,6 +103,7 @@ __all__ = [
"uint8"
,
"uint8"
,
# Operations.
# Operations.
"add"
,
"add"
,
"add_rms_norm"
,
"attention"
,
"attention"
,
"matmul"
,
"matmul"
,
"mul"
,
"mul"
,
...
...
python/infinicore/ops/add_rms_norm.py
0 → 100644
View file @
7d60e5b8
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
    """Fused add + RMS normalization: RMSNorm(a + b) * weight.

    Args:
        a: First input tensor.
        b: Second input tensor.
        weight: Per-channel scale weights.
        epsilon: Stability constant added inside the square root.
        out: Optional pre-allocated output tensor.

    Returns:
        A new ``Tensor`` when ``out`` is None, otherwise ``out`` filled
        with the result.
    """
    if out is None:
        # Allocating path: the C++ binding returns a fresh underlying tensor.
        raw = _infinicore.add_rms_norm(
            a._underlying, b._underlying, weight._underlying, epsilon
        )
        return Tensor(raw)

    # In-place path: write directly into the caller-supplied output.
    _infinicore.add_rms_norm_(
        out._underlying, a._underlying, b._underlying, weight._underlying, epsilon
    )
    return out
src/infinicore/ops/add_rms_norm/add_rms_norm.cc
0 → 100644
View file @
7d60e5b8
#include "infinicore/ops/add_rms_norm.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

// Lazily-constructed singleton registry of per-device implementations.
common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() {
    static common::OpDispatcher<AddRMSNorm::schema> dispatcher_;
    return dispatcher_;
}

// Validates device placement, switches the context to y's device, then
// forwards to the implementation registered for that device type.
void AddRMSNorm::execute(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, a, b, weight);
    infinicore::context::setDevice(y->device());
    auto impl = dispatcher().lookup(y->device().getType());
    impl(y, a, b, weight, epsilon);
}

// Allocating wrapper: output takes a's shape, dtype, and device.
Tensor add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) {
    auto result = Tensor::empty(a->shape(), a->dtype(), a->device());
    add_rms_norm_(result, a, b, weight, epsilon);
    return result;
}

// Out-parameter wrapper around AddRMSNorm::execute.
void add_rms_norm_(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) {
    AddRMSNorm::execute(y, a, b, weight, epsilon);
}

} // namespace infinicore::op
src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc
0 → 100644
View file @
7d60e5b8
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/add_rms_norm.hpp"
#include "infinicore/ops/common/cache.hpp"

#include <infiniop.h>

namespace infinicore::op::add_rms_norm_impl::infiniop {

// Per-thread descriptor cache keyed by a hash of the call signature.
// The eviction callback destroys the underlying infiniop descriptor.
thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches(
    100, // capacity
    [](infiniopAddRMSNormDescriptor_t &desc) {
        if (desc != nullptr) {
            INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc));
            desc = nullptr;
        }
    });

// infiniop-backed implementation registered with AddRMSNorm::dispatcher().
void calculate(Tensor y, Tensor a, Tensor b, Tensor weight, float epsilon) {
    // Cache key covers tensors and epsilon; presumably hash_combine folds in
    // shapes/strides/dtypes so distinct layouts get distinct descriptors.
    const size_t seed = hash_combine(y, a, b, weight, epsilon);
    auto device = context::getDevice();
    auto &cache = caches.getCache(device);

    infiniopAddRMSNormDescriptor_t desc = nullptr;
    if (auto cached = cache.get(seed)) {
        desc = *cached;
    } else {
        INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor(
            context::getInfiniopHandle(device),
            &desc,
            y->desc(),
            a->desc(),
            b->desc(),
            weight->desc(),
            epsilon));
        cache.put(seed, desc);
    }

    // Query and allocate the backend's scratch space for this call.
    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

    INFINICORE_CHECK_ERROR(infiniopAddRMSNorm(
        desc,
        workspace->data(),
        workspace_size,
        y->data(),
        a->data(),
        b->data(),
        weight->data(),
        context::getStream()));
}

// Self-registration at load time; `false` matches the original call and
// presumably means "do not overwrite existing registrations" — confirm
// against OpDispatcher::registerAll.
static bool registered = []() {
    AddRMSNorm::dispatcher().registerAll(&calculate, false);
    return true;
}();

} // namespace infinicore::op::add_rms_norm_impl::infiniop
src/infinicore/pybind11/ops.hpp
View file @
7d60e5b8
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/pybind11.h>
#include "ops/add.hpp"
#include "ops/add.hpp"
#include "ops/add_rms_norm.hpp"
#include "ops/attention.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp"
#include "ops/embedding.hpp"
...
@@ -22,6 +23,7 @@ namespace infinicore::ops {
...
@@ -22,6 +23,7 @@ namespace infinicore::ops {
inline
void
bind
(
py
::
module
&
m
)
{
inline
void
bind
(
py
::
module
&
m
)
{
bind_add
(
m
);
bind_add
(
m
);
bind_add_rms_norm
(
m
);
bind_attention
(
m
);
bind_attention
(
m
);
bind_causal_softmax
(
m
);
bind_causal_softmax
(
m
);
bind_random_sample
(
m
);
bind_random_sample
(
m
);
...
...
src/infinicore/pybind11/ops/add_rms_norm.hpp
0 → 100644
View file @
7d60e5b8
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/add_rms_norm.hpp"

namespace py = pybind11;

namespace infinicore::ops {

// Exposes the fused Add + RMSNorm operator (allocating and in-place forms)
// to Python under the names used by python/infinicore/ops/add_rms_norm.py.
inline void bind_add_rms_norm(py::module &m) {
    m.def("add_rms_norm",
          &op::add_rms_norm,
          py::arg("a"),
          py::arg("b"),
          py::arg("weight"),
          py::arg("epsilon") = 1e-5f,
          R"doc(Fused Add and RMS Normalization.
Args:
    a: First input tensor
    b: Second input tensor
    weight: Scale weights
    epsilon: Small constant for numerical stability, default is 1e-5
Returns:
    Normalized tensor: RMSNorm(a + b) * weight
)doc");

    m.def("add_rms_norm_",
          &op::add_rms_norm_,
          py::arg("y"),
          py::arg("a"),
          py::arg("b"),
          py::arg("weight"),
          py::arg("epsilon") = 1e-5f,
          R"doc(In-place Fused Add and RMS Normalization.
Args:
    y: Output tensor
    a: First input tensor
    b: Second input tensor
    weight: Scale weights
    epsilon: Small constant for numerical stability, default is 1e-5
)doc");
}

} // namespace infinicore::ops
src/infiniop/ops/add_rms_norm/add_rms_norm.h
0 → 100644
View file @
7d60e5b8
#ifndef ADD_RMS_NORM_H
#define ADD_RMS_NORM_H

#include "../../operator.h"
#include "info.h"

// Expands to the backend-specific Descriptor class declaration for the fused
// Add + RMSNorm operator, placed in namespace op::add_rms_norm::<NAMESPACE>.
// Each backend header instantiates this once (e.g. DESCRIPTOR(cpu)); the
// nested Opaque struct plus the destructor, create(), and calculate() bodies
// are then defined in that backend's translation unit.
// NOTE: comments cannot appear inside the macro body — a `//` would swallow
// the line-continuation backslash.
#define DESCRIPTOR(NAMESPACE) \
                                                                           \
    namespace op::add_rms_norm::NAMESPACE {                                \
    class Descriptor final : public InfiniopDescriptor {                   \
        struct Opaque;                                                     \
        Opaque *_opaque;                                                   \
        AddRMSNormInfo _info;                                              \
        size_t _workspace_size;                                            \
                                                                           \
        Descriptor(                                                        \
            Opaque *opaque,                                                \
            AddRMSNormInfo info,                                           \
            size_t workspace_size,                                         \
            infiniDevice_t device_type,                                    \
            int device_id)                                                 \
            : InfiniopDescriptor{device_type, device_id},                  \
              _opaque(opaque),                                             \
              _info(info),                                                 \
              _workspace_size(workspace_size) {}                           \
                                                                           \
    public:                                                                \
        ~Descriptor();                                                     \
                                                                           \
        size_t workspaceSize() const { return _workspace_size; }           \
                                                                           \
        static infiniStatus_t create(                                      \
            infiniopHandle_t handle,                                       \
            Descriptor **desc_ptr,                                         \
            infiniopTensorDescriptor_t y_desc,                             \
            infiniopTensorDescriptor_t a_desc,                             \
            infiniopTensorDescriptor_t b_desc,                             \
            infiniopTensorDescriptor_t weight_desc,                        \
            float epsilon);                                                \
                                                                           \
        infiniStatus_t calculate(                                          \
            void *workspace, size_t workspace_size,                        \
            void *y,                                                       \
            const void *a,                                                 \
            const void *b,                                                 \
            const void *weight,                                            \
            void *stream) const;                                           \
    };                                                                     \
    }

#endif // ADD_RMS_NORM_H
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc
0 → 100644
View file @
7d60e5b8
#include "add_rms_norm_cpu.h"

#include "../../../devices/cpu/common_cpu.h"
#include "../../../reduce/cpu/reduce.h"

namespace op::add_rms_norm::cpu {

// The CPU backend holds no opaque state, so nothing to release here.
Descriptor::~Descriptor() {}

// Validates the tensor descriptors via AddRMSNormInfo and builds a CPU
// descriptor. The CPU path needs no opaque state and no workspace (0 bytes).
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon) {
    auto info = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon);
    CHECK_RESULT(info);

    *desc_ptr = new Descriptor(
        nullptr,      // no opaque backend state on CPU
        info.take(),
        0,            // zero workspace
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Full-precision kernel (T = float or double): for each row,
// y = (a + b) * w / sqrt(mean((a + b)^2) + epsilon).
// Rows are indexed by (batch, head); the last dimension is contiguous
// (enforced by AddRMSNormInfo::create). Each row is read twice instead of
// materializing the sum, trading FLOPs for zero temporary storage.
template <typename T>
infiniStatus_t add_rmsnorm(
    const AddRMSNormInfo *info,
    T *y,
    const T *a,
    const T *b,
    const T *w) {
    const size_t batch_size = info->shape[0];
    // 2-D inputs have a single implicit head; 3-D inputs carry it in dim 1.
    const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
    const size_t dim = info->dim();
    const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);

#pragma omp parallel for
    for (ptrdiff_t block = 0; block < total_blocks; ++block) {
        const size_t batch = block / nhead;
        const size_t head = block % nhead;

        const T *row_a = a + batch * info->a_strides[0] + head * info->a_strides[1];
        const T *row_b = b + batch * info->b_strides[0] + head * info->b_strides[1];
        T *row_y = y + batch * info->y_strides[0] + head * info->y_strides[1];

        // Pass 1: accumulate sum of squares of (a + b).
        T sq_sum = (T)0;
        for (size_t k = 0; k < dim; k++) {
            const T s = row_a[k] + row_b[k];
            sq_sum += s * s;
        }

        // Reciprocal RMS: 1 / sqrt(mean + epsilon).
        const T inv_rms = (T)1 / std::sqrt(sq_sum / (T)(dim) + (T)(info->epsilon));

        // Pass 2: recompute the sum and write the scaled, normalized row.
        for (size_t k = 0; k < dim; k++) {
            const T s = row_a[k] + row_b[k];
            row_y[k] = s * w[k] * inv_rms;
        }
    }
    return INFINI_STATUS_SUCCESS;
}
// Half-precision kernel (T = fp16_t or bf16_t, weights Tw may be the same
// half type or float). All arithmetic is done in float to avoid the severe
// precision loss of accumulating squared sums in 16-bit.
template <typename T, typename Tw>
infiniStatus_t add_rmsnormHalfPrecision(
    const AddRMSNormInfo *info,
    T *y,
    const T *a,
    const T *b,
    const Tw *w) {
    static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value,
                  "T must be fp16_t or bf16_t");

    const size_t batch_size = info->shape[0];
    // Single implicit head for 2-D inputs; dim 1 holds it for 3-D.
    const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1;
    const size_t dim = info->dim();
    const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);

#pragma omp parallel for
    for (ptrdiff_t block = 0; block < total_blocks; ++block) {
        const size_t batch = block / nhead;
        const size_t head = block % nhead;

        const T *row_a = a + batch * info->a_strides[0] + head * info->a_strides[1];
        const T *row_b = b + batch * info->b_strides[0] + head * info->b_strides[1];
        T *row_y = y + batch * info->y_strides[0] + head * info->y_strides[1];

        // Pass 1: float accumulation of sum((a + b)^2).
        float sq_sum = 0.0f;
        for (size_t k = 0; k < dim; k++) {
            const float s = utils::cast<float>(row_a[k]) + utils::cast<float>(row_b[k]);
            sq_sum += s * s;
        }

        // Reciprocal RMS: 1 / sqrt(mean + epsilon).
        const float inv_rms = 1.f / std::sqrt(sq_sum / (float)(dim) + info->epsilon);

        // Pass 2: normalize, scale by the weight, and cast back to T.
        for (size_t k = 0; k < dim; k++) {
            const float s = utils::cast<float>(row_a[k]) + utils::cast<float>(row_b[k]);
            float scaled;
            if constexpr (std::is_same<Tw, float>::value) {
                // Float weights multiply directly.
                scaled = s * w[k] * inv_rms;
            } else if constexpr (std::is_same<Tw, T>::value
                                 || std::is_same_v<Tw, fp16_t>
                                 || std::is_same_v<Tw, bf16_t>) {
                // Half-precision weights are widened to float first.
                scaled = s * utils::cast<float>(w[k]) * inv_rms;
            } else {
                // Unreachable for the instantiations used by calculate().
                std::abort();
            }
            row_y[k] = utils::cast<T>(scaled);
        }
    }
    return INFINI_STATUS_SUCCESS;
}
// Dispatches on the activation dtype (atype) and weight dtype (wtype)
// recorded at descriptor creation. Half-precision activations accept F16,
// BF16, or F32 weights; F32/F64 activations require matching weights
// (enforced by AddRMSNormInfo::create). workspace/stream are unused on CPU.
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y,
    const void *a,
    const void *b,
    const void *weight,
    void *stream) const {
    switch (_info.atype) {
    case INFINI_DTYPE_F16:
        switch (_info.wtype) {
        case INFINI_DTYPE_F16:
            CHECK_STATUS(add_rmsnormHalfPrecision(
                &_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b,
                (const fp16_t *)weight));
            break;
        case INFINI_DTYPE_F32:
            CHECK_STATUS(add_rmsnormHalfPrecision(
                &_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b,
                (const float *)weight));
            break;
        case INFINI_DTYPE_BF16:
            CHECK_STATUS(add_rmsnormHalfPrecision(
                &_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b,
                (const bf16_t *)weight));
            break;
        default:
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        break;
    case INFINI_DTYPE_BF16:
        switch (_info.wtype) {
        case INFINI_DTYPE_BF16:
            CHECK_STATUS(add_rmsnormHalfPrecision(
                &_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b,
                (const bf16_t *)weight));
            break;
        case INFINI_DTYPE_F32:
            CHECK_STATUS(add_rmsnormHalfPrecision(
                &_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b,
                (const float *)weight));
            break;
        case INFINI_DTYPE_F16:
            CHECK_STATUS(add_rmsnormHalfPrecision(
                &_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b,
                (const fp16_t *)weight));
            break;
        default:
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
        break;
    case INFINI_DTYPE_F32:
        CHECK_STATUS(add_rmsnorm(
            &_info, (float *)y, (const float *)a, (const float *)b,
            (const float *)weight));
        break;
    case INFINI_DTYPE_F64:
        CHECK_STATUS(add_rmsnorm(
            &_info, (double *)y, (const double *)a, (const double *)b,
            (const double *)weight));
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::add_rms_norm::cpu
src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.h
0 → 100644
View file @
7d60e5b8
#ifndef __ADD_RMS_NORM_CPU_H__
#define __ADD_RMS_NORM_CPU_H__

#include "../add_rms_norm.h"

// Declares op::add_rms_norm::cpu::Descriptor via the shared DESCRIPTOR macro;
// the member definitions live in cpu/add_rms_norm_cpu.cc.
DESCRIPTOR(cpu)

#endif
src/infiniop/ops/add_rms_norm/info.h
0 → 100644
View file @
7d60e5b8
#ifndef __ADD_RMS_NORM_INFO_H__
#define __ADD_RMS_NORM_INFO_H__

#include "../../../utils.h"
#include "../../tensor.h"

#include <vector>

namespace op::add_rms_norm {

/// Validated metadata for the fused Add + RMSNorm operator: dtypes, epsilon,
/// the output shape, and the strides of y/a/b. Constructed only through
/// create(), which enforces all shape/dtype/stride constraints.
class AddRMSNormInfo {
    AddRMSNormInfo() = default;

public:
    infiniDtype_t wtype;               // weight dtype
    infiniDtype_t atype;               // activation (y/a/b) dtype
    float epsilon;                     // stability constant
    std::vector<size_t> shape;         // output shape (2-D or 3-D)
    std::vector<ptrdiff_t> y_strides;
    std::vector<ptrdiff_t> a_strides;
    std::vector<ptrdiff_t> b_strides;

    // Rank of the tensors.
    size_t ndim() const { return shape.size(); }
    // Size of the normalized (last) dimension.
    size_t dim() const { return shape[ndim() - 1]; }

    /// Validates descriptors and produces an AddRMSNormInfo, or a bad-status
    /// Result. Constraints: y/a/b share dtype and shape (rank 2 or 3), the
    /// weight is 1-D of length dim(), and every last dimension is contiguous.
    /// F16/BF16 activations accept F16/BF16/F32 weights; F32/F64 activations
    /// require weights of the identical dtype.
    static utils::Result<AddRMSNormInfo> create(
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t a_desc,
        infiniopTensorDescriptor_t b_desc,
        infiniopTensorDescriptor_t weight_desc,
        float epsilon) {
        const auto atype = y_desc->dtype();
        const auto wtype = weight_desc->dtype();

        // All activation tensors must share one dtype.
        if (a_desc->dtype() != atype || b_desc->dtype() != atype) {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        if (atype == INFINI_DTYPE_F16 || atype == INFINI_DTYPE_BF16) {
            // Half-precision activations: weights may be F16, BF16, or F32.
            const bool weight_ok = (wtype == atype)
                                || (wtype == INFINI_DTYPE_F32)
                                || (wtype == INFINI_DTYPE_BF16)
                                || (wtype == INFINI_DTYPE_F16);
            if (!weight_ok) {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
        } else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) {
            // Full precision: weights must match exactly.
            if (atype != wtype) {
                return INFINI_STATUS_BAD_TENSOR_DTYPE;
            }
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }

        const size_t y_ndim = y_desc->ndim();
        const size_t a_ndim = a_desc->ndim();
        const size_t b_ndim = b_desc->ndim();
        const size_t w_ndim = weight_desc->ndim();
        // Ranks must agree and the weight must be a vector.
        if (y_ndim != a_ndim || y_ndim != b_ndim || w_ndim != 1) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        size_t batch = 1;
        size_t nhead = 1;
        size_t dim = 0;
        if (y_ndim == 2) {
            // (batch, dim) layout.
            batch = y_desc->dim(0);
            dim = y_desc->dim(1);
            const bool shapes_match = a_desc->dim(0) == batch && a_desc->dim(1) == dim
                                   && b_desc->dim(0) == batch && b_desc->dim(1) == dim
                                   && weight_desc->dim(0) == dim;
            if (!shapes_match) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
        } else if (y_ndim == 3) {
            // (batch, nhead, dim) layout.
            batch = y_desc->dim(0);
            nhead = y_desc->dim(1);
            dim = y_desc->dim(2);
            const bool shapes_match = a_desc->dim(0) == batch && a_desc->dim(1) == nhead && a_desc->dim(2) == dim
                                   && b_desc->dim(0) == batch && b_desc->dim(1) == nhead && b_desc->dim(2) == dim
                                   && weight_desc->dim(0) == dim;
            if (!shapes_match) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
        } else {
            // Only rank 2 and rank 3 are supported.
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        // The CPU kernels index the last dimension directly, so it must be
        // contiguous in every tensor.
        if (y_desc->stride(y_ndim - 1) != 1
            || a_desc->stride(a_ndim - 1) != 1
            || b_desc->stride(b_ndim - 1) != 1
            || weight_desc->stride(w_ndim - 1) != 1) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }

        AddRMSNormInfo info;
        info.wtype = wtype;
        info.atype = atype;
        info.epsilon = epsilon;
        info.shape = y_desc->shape();
        info.y_strides = y_desc->strides();
        info.a_strides = a_desc->strides();
        info.b_strides = b_desc->strides();
        return utils::Result<AddRMSNormInfo>(info);
    }
};

} // namespace op::add_rms_norm

#endif // __ADD_RMS_NORM_INFO_H__
src/infiniop/ops/add_rms_norm/operator.cc
0 → 100644
View file @
7d60e5b8
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/add_rms_norm.h"
#ifdef ENABLE_CPU_API
#include "cpu/add_rms_norm_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
// TODO: Add NVIDIA implementation
// #include "nvidia/add_rms_norm_nvidia.cuh"
#endif
#ifdef ENABLE_ASCEND_API
// TODO: Add Ascend implementation
// #include "ascend/add_rms_norm_aclnn.h"
#endif
#ifdef ENABLE_CAMBRICON_API
// TODO: Add Cambricon implementation
// #include "bang/add_rms_norm_bang.h"
#endif
#ifdef ENABLE_METAX_API
// TODO: Add Metax implementation
// #include "metax/add_rms_norm_metax.cuh"
#endif
#ifdef ENABLE_MOORE_API
// TODO: Add Moore implementation
// #include "moore/add_rms_norm_moore.h"
#endif
#ifdef ENABLE_KUNLUN_API
// TODO: Add Kunlun implementation
// #include "kunlun/add_rms_norm_kunlun.h"
#endif
// C-ABI entry point: routes descriptor creation to the backend matching the
// handle's device type. Only the CPU backend is implemented; the commented
// branches are placeholders for future devices.
__C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
    infiniopHandle_t handle,
    infiniopAddRMSNormDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon) {

#define CREATE(CASE, NAMESPACE)                                                   \
    case CASE:                                                                    \
        return op::add_rms_norm::NAMESPACE::Descriptor::create(                   \
            handle,                                                               \
            reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor **>(desc_ptr), \
            y_desc,                                                               \
            a_desc,                                                               \
            b_desc,                                                               \
            weight_desc,                                                          \
            epsilon)

    switch (handle->device) {
#ifdef ENABLE_CPU_API
        CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        // CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        // CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        // CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        // CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
        // CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}
// C-ABI entry point: reports the workspace size recorded in the descriptor.
// Fix vs. the original: the duplicate `return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;`
// after the switch was unreachable dead code (every switch path returns,
// including `default:`) and inconsistent with the sibling entry points,
// which have no trailing return. It has been removed.
__C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(
    infiniopAddRMSNormDescriptor_t desc,
    size_t *size) {

#define GET(CASE, NAMESPACE)                                                                      \
    case CASE:                                                                                    \
        *size = reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        // GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        // GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        // GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        // GET(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
        // GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET
}
// C-ABI entry point: dispatches execution to the backend recorded in the
// descriptor's device_type.
__C infiniStatus_t infiniopAddRMSNorm(
    infiniopAddRMSNormDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *a,
    const void *b,
    const void *weight,
    void *stream) {

#define CALCULATE(CASE, NAMESPACE)                                               \
    case CASE:                                                                   \
        return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc) \
            ->calculate(workspace, workspace_size, y, a, b, weight, stream)

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        // CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        // CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        // CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
        // CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}
// C-ABI entry point: destroys a descriptor. Null is accepted and treated as
// a successful no-op.
__C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(
    infiniopAddRMSNormDescriptor_t desc) {
    if (desc == nullptr) {
        return INFINI_STATUS_SUCCESS;
    }

#define DESTROY(CASE, NAMESPACE)                                              \
    case CASE:                                                                \
        delete reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        DESTROY(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
        // DESTROY(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
        // DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
        // DESTROY(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
        // DESTROY(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
        // DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DESTROY
}
test/infinicore/ops/add_rms_norm.py
0 → 100644
View file @
7d60e5b8
import
sys
import
os
sys
.
path
.
insert
(
0
,
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
".."
))
import
torch
import
infinicore
from
framework
import
(
BaseOperatorTest
,
TensorSpec
,
TestCase
,
GenericTestRunner
,
is_broadcast
,
)
# ==============================================================================
# Operator-specific configuration
# ==============================================================================

# Test cases format: (y_shape, a_shape, b_shape, w_shape, y_strides, a_strides, b_strides)
# A strides entry of None means the tensor is contiguous.
_TEST_CASES_DATA = [
    # Basic cases
    ((1, 4), (1, 4), (1, 4), (4,), None, None, None),
    ((2, 4), (2, 4), (2, 4), (4,), None, None, None),
    ((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), None, None, None),
    # Strided cases
    ((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1), (12, 8, 1)),
    # Large tensors
    ((16, 2048), (16, 2048), (16, 2048), (2048,), None, None, None),
    ((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
    ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
    ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
    ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)),
    ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)),
]

# Tolerance configuration, keyed by the dtype of the a/b inputs
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 2e-3, "rtol": 2e-3},
    infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-5, "rtol": 1e-4},
}

# Data types for individual tensors
_INPUT_DTYPES = [infinicore.float16, infinicore.bfloat16]
_WEIGHT_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]

# EPSILON constant for AddRMSNorm
_EPSILON = 1e-5
def parse_test_cases():
    """
    Parse AddRMSNorm test case data and return list of TestCase objects.

    Each _TEST_CASES_DATA entry has the format:
        (y_shape, a_shape, b_shape, w_shape, y_strides, a_strides, b_strides)
    where a strides entry of None means contiguous. For every entry, one case
    is produced per (input dtype, weight dtype) pair, covering the
    out-of-place call plus each in-place variant the strides allow.

    Fixes vs. the original: the constant ``description`` strings carried a
    useless ``f`` prefix (no placeholders), and the four TestCase
    constructions were near-duplicates; both are cleaned up here. The
    produced strings and TestCase fields are unchanged.
    """
    test_cases = []
    for data in _TEST_CASES_DATA:
        y_shape = data[0]  # Output shape
        a_shape = data[1]  # First input shape
        b_shape = data[2]  # Second input shape
        w_shape = data[3]  # Weight shape (1D)
        y_strides = data[4] if len(data) > 4 else None
        a_strides = data[5] if len(data) > 5 else None
        b_strides = data[6] if len(data) > 6 else None

        # Check if tensors support in-place operations (broadcast views
        # cannot be written to).
        a_supports_inplace = not is_broadcast(a_strides)
        b_supports_inplace = not is_broadcast(b_strides)
        y_supports_inplace = not is_broadcast(y_strides)

        # Generate test cases for all dtype combinations
        for input_dtype in _INPUT_DTYPES:
            for weight_dtype in _WEIGHT_DTYPES:
                # Use input dtype tolerance for output
                tolerance = _TOLERANCE_MAP.get(
                    input_dtype, {"atol": 1e-5, "rtol": 1e-4}
                )

                # Create typed tensor specs
                a_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
                b_spec = TensorSpec.from_tensor(b_shape, b_strides, input_dtype)
                # Weight is always contiguous
                w_spec = TensorSpec.from_tensor(w_shape, None, weight_dtype)
                y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype)

                def make_case(kwargs, output_spec, comparison_target, description):
                    # Build one TestCase; a fresh inputs list per case so
                    # cases never share mutable state.
                    return TestCase(
                        inputs=[a_spec, b_spec, w_spec],
                        kwargs=kwargs,
                        output_spec=output_spec,
                        comparison_target=comparison_target,
                        tolerance=tolerance,
                        description=description,
                    )

                # Test Case 1: Out-of-place (return value)
                test_cases.append(
                    make_case({"epsilon": _EPSILON}, None, None,
                              "AddRMSNorm - OUT_OF_PLACE")
                )
                # Test Case 2: In-place with explicit output tensor
                # (add_rms_norm(a, b, w, out=y))
                if y_supports_inplace:
                    test_cases.append(
                        make_case({"epsilon": _EPSILON}, y_spec, "out",
                                  "AddRMSNorm - INPLACE(out)")
                    )
                # Test Case 3: In-place on first input (out=a); index 0 refers
                # to the first input.
                if a_supports_inplace:
                    test_cases.append(
                        make_case({"out": 0, "epsilon": _EPSILON}, None, 0,
                                  "AddRMSNorm - INPLACE(a)")
                    )
                # Test Case 4: In-place on second input (out=b); index 1
                # refers to the second input.
                if b_supports_inplace:
                    test_cases.append(
                        make_case({"out": 1, "epsilon": _EPSILON}, None, 1,
                                  "AddRMSNorm - INPLACE(b)")
                    )
    return test_cases
class OpTest(BaseOperatorTest):
    """Test driver for the AddRMSNorm operator."""

    def __init__(self):
        super().__init__("AddRMSNorm")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
        """Reference AddRMSNorm built from PyTorch primitives.

        Computes rms_norm(a + b) scaled by ``weight`` in float32, then casts
        the result back to the dtype of ``a``. If ``out`` is given the result
        is copied into it and ``out`` is returned.
        """
        target_dtype = a.dtype
        # Accumulate in float32 regardless of input dtype for precision.
        hidden = a.to(torch.float32) + b.to(torch.float32)
        # RMSNorm: (a + b) * weight / sqrt(mean((a+b)^2) + epsilon)
        rms_scale = torch.rsqrt(hidden.pow(2).mean(-1, keepdim=True) + epsilon)
        normalized = hidden * rms_scale * weight.to(torch.float32)
        normalized = normalized.to(target_dtype)
        if out is None:
            return normalized
        out.copy_(normalized)
        return out

    def infinicore_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
        """Run the InfiniCore AddRMSNorm implementation under test."""
        return infinicore.add_rms_norm(a, b, weight, epsilon, out=out)
def main():
    """Entry point: run the AddRMSNorm suite and exit with its status."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()
test/infiniop/add_rms_norm.py
0 → 100644
View file @
7d60e5b8
import
torch
import
ctypes
from
ctypes
import
c_uint64
from
libinfiniop
import
(
LIBINFINIOP
,
TestTensor
,
get_test_devices
,
check_error
,
test_operator
,
get_args
,
debug
,
get_tolerance
,
profile_operation
,
TestWorkspace
,
InfiniDtype
,
InfiniDtypeNames
,
InfiniDeviceNames
,
infiniopOperatorDescriptor_t
,
)
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # y_shape, a_shape, b_shape, w_shape, y_stride, a_stride, b_stride
    # (a stride of None means contiguous)
    ((1, 4), (1, 4), (1, 4), (4,), None, None, None),
    ((2, 4), (2, 4), (2, 4), (4,), None, None, None),
    ((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), None, None, None),
    ((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1), (12, 8, 1)),
    ((16, 2048), (16, 2048), (16, 2048), (2048,), None, None, None),
    ((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
    ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
    ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
    ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)),
    ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)),
    # NOTE(review): duplicate of the (15, 3584) case above — presumably
    # unintentional; confirm before removing.
    ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
    ((15, 8192), (15, 8192), (15, 8192), (8192,), None, None, None),
]

# w (weight) types
# Note: 'None' means the same as input dtype
_WEIGHT_DTYPES = [None, InfiniDtype.F32, InfiniDtype.F16, InfiniDtype.BF16]

# a, b types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16]

# Form the test cases by appending each element of _WEIGHT_DTYPES to each tuple in _TEST_CASES_
_TEST_CASES = [
    test_case + (w_dtype,)
    for test_case in _TEST_CASES_
    for w_dtype in _WEIGHT_DTYPES
]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 2e-3, "rtol": 2e-3},
    InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
}

# Runtime flags; overwritten from CLI args in the __main__ block below.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
def add_rms_norm(ans, a, b, w, eps):
    """Reference implementation: ans <- rms_norm(a + b, w, eps) in a's dtype."""
    out_dtype = a.dtype
    # Accumulate in float32 for precision regardless of input dtype.
    acc = a.to(torch.float32) + b.to(torch.float32)
    # 1 / sqrt(mean((a+b)^2) + eps), reduced over the last dimension.
    inv_rms = torch.rsqrt(acc.pow(2).mean(-1, keepdim=True) + eps)
    normalized = acc * inv_rms * w.to(torch.float32)
    # Rebind ans' storage to the freshly computed result.
    ans.set_(normalized.to(out_dtype))
def test(
    handle,
    device,
    y_shape,
    a_shape,
    b_shape,
    w_shape,
    y_stride,
    a_stride,
    b_stride,
    w_dtype=InfiniDtype.F32,
    dtype=InfiniDtype.F16,
    sync=None,
):
    # Run one AddRMSNorm case: compute a PyTorch reference, run the library
    # kernel through the C API, and compare the results within tolerance.
    # A falsy w_dtype means "use the same dtype as a/b".
    w_dtype = w_dtype if w_dtype else dtype
    print(
        f"Testing AddRMSNorm on {InfiniDeviceNames[device]} with y_shape:{y_shape} a_shape:{a_shape} b_shape:{b_shape} w_shape:{w_shape}"
        f" y_stride:{y_stride} a_stride:{a_stride} b_stride:{b_stride} w_dtype:{InfiniDtypeNames[w_dtype]} dtype:{InfiniDtypeNames[dtype]}"
    )

    # y starts as ones and is overwritten; a/b are scaled small to keep the
    # float32 reference numerically stable; w is contiguous.
    y = TestTensor(y_shape, y_stride, dtype, device, mode="ones")
    a = TestTensor(a_shape, a_stride, dtype, device, scale=0.01)
    b = TestTensor(b_shape, b_stride, dtype, device, scale=0.01)
    w = TestTensor(w_shape, None, w_dtype, device)

    eps = 1e-6
    # Reference result is written into y.torch_tensor() (via set_).
    add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps)

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateAddRMSNormDescriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            a.descriptor,
            b.descriptor,
            w.descriptor,
            eps,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [a, b, y, w]:
        tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetAddRMSNormWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, y.device)

    def lib_add_rms_norm():
        # Invoke the library kernel; argument order matches the C API:
        # desc, workspace, workspace_size, y, a, b, weight, stream.
        check_error(
            LIBINFINIOP.infiniopAddRMSNorm(
                descriptor,
                workspace.data(),
                workspace_size.value,
                y.data(),
                a.data(),
                b.data(),
                w.data(),
                None,
            )
        )

    lib_add_rms_norm()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_add_rms_norm(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(LIBINFINIOP.infiniopDestroyAddRMSNormDescriptor(descriptor))
if __name__ == "__main__":
    args = get_args()

    # Configure testing options (overrides the module-level defaults)
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests on every device selected by the CLI arguments
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")
test/infiniop/libinfiniop/op_register.py
View file @
7d60e5b8
...
@@ -383,6 +383,43 @@ def rms_norm_(lib):
...
@@ -383,6 +383,43 @@ def rms_norm_(lib):
]
]
@OpRegister.operator
def add_rms_norm_(lib):
    # Register the ctypes signatures for the AddRMSNorm C API so Python
    # calls marshal their arguments correctly.
    lib.infiniopCreateAddRMSNormDescriptor.restype = c_int32
    lib.infiniopCreateAddRMSNormDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,  # y
        infiniopTensorDescriptor_t,  # a
        infiniopTensorDescriptor_t,  # b
        infiniopTensorDescriptor_t,  # weight
        c_float,  # epsilon
    ]
    lib.infiniopGetAddRMSNormWorkspaceSize.restype = c_int32
    lib.infiniopGetAddRMSNormWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]
    lib.infiniopAddRMSNorm.restype = c_int32
    lib.infiniopAddRMSNorm.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,  # workspace
        c_size_t,  # workspace_size
        c_void_p,  # y
        c_void_p,  # a
        c_void_p,  # b
        c_void_p,  # weight
        c_void_p,  # stream
    ]
    lib.infiniopDestroyAddRMSNormDescriptor.restype = c_int32
    lib.infiniopDestroyAddRMSNormDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]
@
OpRegister
.
operator
@
OpRegister
.
operator
def
rope_
(
lib
):
def
rope_
(
lib
):
lib
.
infiniopCreateRoPEDescriptor
.
restype
=
c_int32
lib
.
infiniopCreateRoPEDescriptor
.
restype
=
c_int32
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment