Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
67d81412
Commit
67d81412
authored
Mar 13, 2025
by
PanZezhong
Browse files
issue/5 添加reduce类通用代码,实现rms norm cpu算子
parent
fd0242ed
Changes
18
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
462 additions
and
95 deletions
+462
-95
.github/workflows/build.yml
.github/workflows/build.yml
+1
-0
include/infiniop/ops/rms_norm.h
include/infiniop/ops/rms_norm.h
+1
-1
src/infiniop/devices/cpu/common_cpu.cc
src/infiniop/devices/cpu/common_cpu.cc
+0
-59
src/infiniop/devices/cpu/common_cpu.h
src/infiniop/devices/cpu/common_cpu.h
+4
-5
src/infiniop/ops/matmul/cpu/matmul_cpu.cc
src/infiniop/ops/matmul/cpu/matmul_cpu.cc
+6
-7
src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.cc
src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.cc
+92
-0
src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.h
src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.h
+7
-0
src/infiniop/ops/rms_norm/operator.cc
src/infiniop/ops/rms_norm/operator.cc
+53
-12
src/infiniop/ops/rms_norm/rms_norm.h
src/infiniop/ops/rms_norm/rms_norm.h
+94
-0
src/infiniop/reduce/cpu/reduce.h
src/infiniop/reduce/cpu/reduce.h
+73
-0
src/utils.h
src/utils.h
+1
-0
src/utils/custom_types.cc
src/utils/custom_types.cc
+63
-0
src/utils/custom_types.h
src/utils/custom_types.h
+40
-0
src/utils/rearrange.cc
src/utils/rearrange.cc
+6
-1
test/infiniop/rms_norm.py
test/infiniop/rms_norm.py
+4
-3
xmake.lua
xmake.lua
+5
-2
xmake/cpu.lua
xmake/cpu.lua
+10
-5
xmake/cuda.lua
xmake/cuda.lua
+2
-0
No files found.
.github/workflows/build.yml
View file @
67d81412
...
@@ -46,3 +46,4 @@ jobs:
...
@@ -46,3 +46,4 @@ jobs:
run
:
|
run
:
|
pip install torch
pip install torch
LD_LIBRARY_PATH=$HOME/.infini/lib python test/infiniop/matmul.py --cpu
LD_LIBRARY_PATH=$HOME/.infini/lib python test/infiniop/matmul.py --cpu
LD_LIBRARY_PATH=$HOME/.infini/lib python test/infiniop/rms_norm.py --cpu
include/infiniop/ops/rms_norm.h
View file @
67d81412
...
@@ -16,7 +16,7 @@ __C __export infiniStatus_t infiniopCreateRMSNormDescriptor(
...
@@ -16,7 +16,7 @@ __C __export infiniStatus_t infiniopCreateRMSNormDescriptor(
__C
__export
infiniStatus_t
infiniopGetRMSNormWorkspaceSize
(
infiniopRMSNormDescriptor_t
desc
,
size_t
*
size
);
__C
__export
infiniStatus_t
infiniopGetRMSNormWorkspaceSize
(
infiniopRMSNormDescriptor_t
desc
,
size_t
*
size
);
__C
__export
infiniStatus_t
infiniopRMSNorm
(
infiniopRMSNormDescriptor_t
desc
,
void
*
workspace
,
size_t
workspace_size
,
__C
__export
infiniStatus_t
infiniopRMSNorm
(
infiniopRMSNormDescriptor_t
desc
,
void
*
workspace
,
size_t
workspace_size
,
void
*
y
,
void
const
*
x
,
void
const
*
w
,
void
*
stream
);
void
*
y
,
const
void
*
x
,
const
void
*
w
,
void
*
stream
);
__C
__export
infiniStatus_t
infiniopDestroyRMSNormDescriptor
(
infiniopRMSNormDescriptor_t
desc
);
__C
__export
infiniStatus_t
infiniopDestroyRMSNormDescriptor
(
infiniopRMSNormDescriptor_t
desc
);
...
...
src/infiniop/devices/cpu/common_cpu.cc
View file @
67d81412
#include "common_cpu.h"
#include "common_cpu.h"
float
f16_to_f32
(
uint16_t
h
)
{
uint32_t
sign
=
(
h
&
0x8000
)
<<
16
;
int32_t
exponent
=
(
h
>>
10
)
&
0x1F
;
uint32_t
mantissa
=
h
&
0x3FF
;
uint32_t
f32
;
if
(
exponent
==
31
)
{
if
(
mantissa
!=
0
)
{
f32
=
sign
|
0x7F800000
|
(
mantissa
<<
13
);
}
else
{
f32
=
sign
|
0x7F800000
;
}
}
else
if
(
exponent
==
0
)
{
if
(
mantissa
==
0
)
{
f32
=
sign
;
}
else
{
exponent
=
-
14
;
while
((
mantissa
&
0x400
)
==
0
)
{
mantissa
<<=
1
;
exponent
--
;
}
mantissa
&=
0x3FF
;
f32
=
sign
|
((
exponent
+
127
)
<<
23
)
|
(
mantissa
<<
13
);
}
}
else
{
f32
=
sign
|
((
exponent
+
127
-
15
)
<<
23
)
|
(
mantissa
<<
13
);
}
float
result
;
memcpy
(
&
result
,
&
f32
,
sizeof
(
result
));
return
result
;
}
uint16_t
f32_to_f16
(
float
val
)
{
uint32_t
f32
;
memcpy
(
&
f32
,
&
val
,
sizeof
(
f32
));
// Read the bits of the float32
uint16_t
sign
=
(
f32
>>
16
)
&
0x8000
;
// Extract the sign bit
int32_t
exponent
=
((
f32
>>
23
)
&
0xFF
)
-
127
;
// Extract and de-bias the exponent
uint32_t
mantissa
=
f32
&
0x7FFFFF
;
// Extract the mantissa (fraction part)
if
(
exponent
>=
31
)
{
// Special cases for Inf and NaN
// NaN
if
(
exponent
==
128
&&
mantissa
!=
0
)
{
return
sign
|
0x7E00
;
}
// Infinity
return
sign
|
0x7C00
;
}
else
if
(
exponent
>=
-
14
)
{
// Normalized case
return
(
uint16_t
)(
sign
|
((
exponent
+
15
)
<<
10
)
|
(
mantissa
>>
13
));
}
else
if
(
exponent
>=
-
24
)
{
mantissa
|=
0x800000
;
// Add implicit leading 1
mantissa
>>=
(
-
14
-
exponent
);
return
(
uint16_t
)(
sign
|
(
mantissa
>>
13
));
}
else
{
// Too small for subnormal: return signed zero
return
(
uint16_t
)
sign
;
}
}
size_t
indexToReducedOffset
(
size_t
indexToReducedOffset
(
size_t
flat_index
,
size_t
flat_index
,
size_t
ndim
,
size_t
ndim
,
...
...
src/infiniop/devices/cpu/common_cpu.h
View file @
67d81412
...
@@ -2,17 +2,16 @@
...
@@ -2,17 +2,16 @@
#define __INFINIOP_COMMON_CPU_H__
#define __INFINIOP_COMMON_CPU_H__
#include "../../../utils.h"
#include "../../../utils.h"
#include "cpu_handle.h"
#include <cmath>
#include <cmath>
#include <cstddef>
#include <cstddef>
#include <cstdint>
#include <cstdint>
#include <cstring>
#include <cstring>
#include <vector>
#include <vector>
// convert half-precision float to single-precision float
#ifdef ENABLE_OMP
float
f16_to_f32
(
uint16_t
code
);
#include <omp.h>
#endif
// convert single-precision float to half-precision float
uint16_t
f32_to_f16
(
float
val
);
// return the memory offset of original tensor, given the flattened index of broadcasted tensor
// return the memory offset of original tensor, given the flattened index of broadcasted tensor
size_t
indexToReducedOffset
(
size_t
flat_index
,
size_t
ndim
,
const
ptrdiff_t
*
broadcasted_strides
,
const
ptrdiff_t
*
target_strides
);
size_t
indexToReducedOffset
(
size_t
flat_index
,
size_t
ndim
,
const
ptrdiff_t
*
broadcasted_strides
,
const
ptrdiff_t
*
target_strides
);
...
...
src/infiniop/ops/matmul/cpu/matmul_cpu.cc
View file @
67d81412
#include "matmul_cpu.h"
#include "matmul_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../devices/cpu/cpu_handle.h"
namespace
op
::
matmul
::
cpu
{
namespace
op
::
matmul
::
cpu
{
...
@@ -52,17 +51,17 @@ void calculate(
...
@@ -52,17 +51,17 @@ void calculate(
for
(
size_t
k_
=
0
;
k_
<
info
.
k
;
++
k_
)
{
for
(
size_t
k_
=
0
;
k_
<
info
.
k
;
++
k_
)
{
auto
a_
=
reinterpret_cast
<
const
Tdata
*>
(
a
)
+
i
*
info
.
a_matrix
.
stride
+
m_
*
info
.
a_matrix
.
row_stride
+
k_
*
info
.
a_matrix
.
col_stride
;
auto
a_
=
reinterpret_cast
<
const
Tdata
*>
(
a
)
+
i
*
info
.
a_matrix
.
stride
+
m_
*
info
.
a_matrix
.
row_stride
+
k_
*
info
.
a_matrix
.
col_stride
;
auto
b_
=
reinterpret_cast
<
const
Tdata
*>
(
b
)
+
i
*
info
.
b_matrix
.
stride
+
n_
*
info
.
b_matrix
.
col_stride
+
k_
*
info
.
b_matrix
.
row_stride
;
auto
b_
=
reinterpret_cast
<
const
Tdata
*>
(
b
)
+
i
*
info
.
b_matrix
.
stride
+
n_
*
info
.
b_matrix
.
col_stride
+
k_
*
info
.
b_matrix
.
row_stride
;
if
constexpr
(
std
::
is_same
<
Tdata
,
uint
16_t
>::
value
)
{
if
constexpr
(
std
::
is_same
<
Tdata
,
fp
16_t
>::
value
)
{
sum
+=
f16_to_f32
(
*
a_
)
*
f16_to_f32
(
*
b_
);
sum
+=
utils
::
cast
<
float
>
(
*
a_
)
*
utils
::
cast
<
float
>
(
*
b_
);
}
else
{
}
else
{
sum
+=
*
a_
*
(
*
b_
);
sum
+=
*
a_
*
(
*
b_
);
}
}
}
}
if
constexpr
(
std
::
is_same
<
Tdata
,
uint
16_t
>::
value
)
{
if
constexpr
(
std
::
is_same
<
Tdata
,
fp
16_t
>::
value
)
{
if
(
beta
==
0
)
{
if
(
beta
==
0
)
{
*
c_
=
f32_to_f16
(
alpha
*
sum
);
*
c_
=
utils
::
cast
<
fp16_t
>
(
alpha
*
sum
);
}
else
{
}
else
{
*
c_
=
f32_to_f16
(
beta
*
f16_to_f32
(
*
c_
)
+
alpha
*
sum
);
*
c_
=
utils
::
cast
<
fp16_t
>
(
beta
*
utils
::
cast
<
float
>
(
*
c_
)
+
alpha
*
sum
);
}
}
}
else
{
}
else
{
*
c_
=
beta
*
(
*
c_
)
+
alpha
*
sum
;
*
c_
=
beta
*
(
*
c_
)
+
alpha
*
sum
;
...
@@ -84,7 +83,7 @@ infiniStatus_t Descriptor::calculate(
...
@@ -84,7 +83,7 @@ infiniStatus_t Descriptor::calculate(
switch
(
_dtype
)
{
switch
(
_dtype
)
{
case
INFINI_DTYPE_F16
:
case
INFINI_DTYPE_F16
:
cpu
::
calculate
<
uint
16_t
>
(
_info
,
c
,
beta
,
a
,
b
,
alpha
);
cpu
::
calculate
<
fp
16_t
>
(
_info
,
c
,
beta
,
a
,
b
,
alpha
);
return
INFINI_STATUS_SUCCESS
;
return
INFINI_STATUS_SUCCESS
;
case
INFINI_DTYPE_F32
:
case
INFINI_DTYPE_F32
:
...
...
src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.cc
0 → 100644
View file @
67d81412
#include "rms_norm_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../../../reduce/cpu/reduce.h"
namespace
op
::
rms_norm
::
cpu
{
Descriptor
::~
Descriptor
()
{}
infiniStatus_t
Descriptor
::
create
(
infiniopHandle_t
handle
,
Descriptor
**
desc_ptr
,
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
x_desc
,
infiniopTensorDescriptor_t
w_desc
,
float
epsilon
)
{
RMSNormInfo
info
;
CHECK_STATUS
(
createRMSNormInfo
(
&
info
,
y_desc
,
x_desc
,
w_desc
,
epsilon
));
*
desc_ptr
=
new
Descriptor
(
nullptr
,
info
,
0
,
handle
->
device
,
handle
->
device_id
);
return
INFINI_STATUS_SUCCESS
;
}
template
<
typename
T
>
infiniStatus_t
rmsnorm
(
const
RMSNormInfo
*
info
,
T
*
y
,
const
T
*
x
,
const
T
*
w
)
{
#pragma omp parallel for
for
(
ptrdiff_t
i
=
0
;
i
<
ptrdiff_t
(
info
->
shape
[
0
]);
i
++
)
{
T
*
x_
=
(
T
*
)(
x
+
i
*
info
->
x_strides
[
0
]);
T
*
y_
=
(
T
*
)(
y
+
i
*
info
->
y_strides
[
0
]);
// [Reduce] sum of x^2 on last dimension
T
ss
=
op
::
common_cpu
::
reduce_op
::
sumSquared
(
x_
,
info
->
shape
[
1
],
info
->
x_strides
[
1
]);
// 1 / (sqrt(sum/dim + eps))
T
rms
=
(
T
)
1
/
std
::
sqrt
(
ss
/
(
T
)(
info
->
shape
[
1
])
+
(
T
)(
info
->
epsilon
));
for
(
size_t
j
=
0
;
j
<
info
->
shape
[
1
];
j
++
)
{
y_
[
j
*
info
->
y_strides
[
1
]]
=
x_
[
j
*
info
->
x_strides
[
1
]]
*
w
[
j
]
*
rms
;
}
}
return
INFINI_STATUS_SUCCESS
;
}
template
<
typename
Tw
>
infiniStatus_t
rmsnormF16
(
const
RMSNormInfo
*
info
,
fp16_t
*
y
,
const
fp16_t
*
x
,
const
Tw
*
w
)
{
#pragma omp parallel for
for
(
ptrdiff_t
i
=
0
;
i
<
ptrdiff_t
(
info
->
shape
[
0
]);
i
++
)
{
fp16_t
*
x_
=
(
fp16_t
*
)(
x
+
i
*
info
->
x_strides
[
0
]);
fp16_t
*
y_
=
(
fp16_t
*
)(
y
+
i
*
info
->
y_strides
[
0
]);
// [Reduce] sum of x^2 on last dimension
float
ss
=
op
::
common_cpu
::
reduce_op
::
sumSquared
(
x_
,
info
->
shape
[
1
],
info
->
x_strides
[
1
]);
// 1 / (sqrt(sum/dim + eps))
float
rms
=
1.
f
/
std
::
sqrt
(
ss
/
(
float
)(
info
->
shape
[
1
])
+
info
->
epsilon
);
for
(
size_t
j
=
0
;
j
<
info
->
shape
[
1
];
j
++
)
{
if
constexpr
(
std
::
is_same
<
Tw
,
float
>::
value
)
{
float
val
=
utils
::
cast
<
float
>
(
x_
[
j
*
info
->
x_strides
[
1
]])
*
w
[
j
]
*
rms
;
y_
[
j
*
info
->
y_strides
[
1
]]
=
utils
::
cast
<
fp16_t
>
(
val
);
}
else
if
constexpr
(
std
::
is_same
<
Tw
,
fp16_t
>::
value
)
{
float
val
=
utils
::
cast
<
float
>
(
x_
[
j
*
info
->
x_strides
[
1
]])
*
utils
::
cast
<
float
>
(
w
[
j
])
*
rms
;
y_
[
j
*
info
->
y_strides
[
1
]]
=
utils
::
cast
<
fp16_t
>
(
val
);
}
else
{
std
::
abort
();
}
}
}
return
INFINI_STATUS_SUCCESS
;
}
infiniStatus_t
Descriptor
::
calculate
(
void
*
workspace
,
size_t
workspace_size
,
void
*
y
,
const
void
*
x
,
const
void
*
w
,
void
*
stream
)
{
if
(
_info
.
atype
==
INFINI_DTYPE_F16
)
{
if
(
_info
.
wtype
==
INFINI_DTYPE_F16
)
{
CHECK_STATUS
(
rmsnormF16
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
x
,
(
const
fp16_t
*
)
w
));
}
else
if
(
_info
.
wtype
==
INFINI_DTYPE_F32
)
{
CHECK_STATUS
(
rmsnormF16
(
&
_info
,
(
fp16_t
*
)
y
,
(
const
fp16_t
*
)
x
,
(
const
float
*
)
w
));
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_F32
)
{
CHECK_STATUS
(
rmsnorm
(
&
_info
,
(
float
*
)
y
,
(
float
*
)
x
,
(
float
*
)
w
));
}
else
if
(
_info
.
atype
==
INFINI_DTYPE_F64
)
{
CHECK_STATUS
(
rmsnorm
(
&
_info
,
(
double
*
)
y
,
(
double
*
)
x
,
(
double
*
)
w
));
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
return
INFINI_STATUS_SUCCESS
;
}
}
// namespace op::rms_norm::cpu
src/infiniop/ops/rms_norm/cpu/rms_norm_cpu.h
0 → 100644
View file @
67d81412
#ifndef __RMS_NORM_CPU_H__
#define __RMS_NORM_CPU_H__
#include "../rms_norm.h"
DESCRIPTOR
(
cpu
)
#endif
src/infiniop/ops/rms_norm/operator.cc
View file @
67d81412
...
@@ -2,6 +2,10 @@
...
@@ -2,6 +2,10 @@
#include "../../handle.h"
#include "../../handle.h"
#include "infiniop/ops/rms_norm.h"
#include "infiniop/ops/rms_norm.h"
#ifdef ENABLE_CPU_API
#include "cpu/rms_norm_cpu.h"
#endif
__C
infiniStatus_t
infiniopCreateRMSNormDescriptor
(
__C
infiniStatus_t
infiniopCreateRMSNormDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
infiniopRMSNormDescriptor_t
*
desc_ptr
,
infiniopRMSNormDescriptor_t
*
desc_ptr
,
...
@@ -9,10 +13,20 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
...
@@ -9,10 +13,20 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
infiniopTensorDescriptor_t
x_desc
,
infiniopTensorDescriptor_t
x_desc
,
infiniopTensorDescriptor_t
w_desc
,
infiniopTensorDescriptor_t
w_desc
,
float
epsilon
)
{
float
epsilon
)
{
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::rms_norm::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::rms_norm::NAMESPACE::Descriptor **>(desc_ptr), \
y_desc, \
x_desc, \
w_desc, \
epsilon);
switch
(
handle
->
device
)
{
switch
(
handle
->
device
)
{
#ifdef ENABLE_CPU
#ifdef ENABLE_CPU_API
case
DevCpu
:
CREATE
(
INFINI_DEVICE_CPU
,
cpu
)
return
cpuCreateRMSNormDescriptor
(
handle
,
(
RMSNormCpuDescriptor_t
*
)
desc_ptr
,
y_desc
,
x_desc
,
w_desc
,
epsilon
);
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_NV_GPU
case
DevNvGpu
:
{
case
DevNvGpu
:
{
...
@@ -45,14 +59,22 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
...
@@ -45,14 +59,22 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
}
}
#endif
#endif
}
}
#undef CREATE
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
__C
infiniStatus_t
infiniopGetRMSNormWorkspaceSize
(
infiniopRMSNormDescriptor_t
desc
,
size_t
*
size
)
{
__C
infiniStatus_t
infiniopGetRMSNormWorkspaceSize
(
infiniopRMSNormDescriptor_t
desc
,
size_t
*
size
)
{
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::rms_norm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_CPU
#ifdef ENABLE_CPU_API
case
DevCpu
:
GET
(
INFINI_DEVICE_CPU
,
cpu
)
return
cpuGetRMSNormWorkspaceSize
((
RMSNormCpuDescriptor_t
)
desc
,
size
);
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_NV_GPU
case
DevNvGpu
:
{
case
DevNvGpu
:
{
...
@@ -82,15 +104,23 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
...
@@ -82,15 +104,23 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
}
}
#endif
#endif
}
}
#undef GET
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
__C
infiniStatus_t
infiniopRMSNorm
(
infiniopRMSNormDescriptor_t
desc
,
void
*
workspace
,
size_t
workspace_size
,
__C
infiniStatus_t
infiniopRMSNorm
(
infiniopRMSNormDescriptor_t
desc
,
void
*
workspace
,
size_t
workspace_size
,
void
*
y
,
const
void
*
x
,
const
void
*
w
,
void
*
stream
)
{
void
*
y
,
const
void
*
x
,
const
void
*
w
,
void
*
stream
)
{
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::rms_norm::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, y, x, w, stream);
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_CPU
#ifdef ENABLE_CPU_API
case
DevCpu
:
CALCULATE
(
INFINI_DEVICE_CPU
,
cpu
)
return
cpuRMSNorm
((
RMSNormCpuDescriptor_t
)
desc
,
workspace
,
workspace_size
,
y
,
x
,
w
,
stream
);
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_NV_GPU
case
DevNvGpu
:
{
case
DevNvGpu
:
{
...
@@ -125,14 +155,22 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
...
@@ -125,14 +155,22 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
}
}
#endif
#endif
}
}
#undef CALCULATE
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
__C
infiniStatus_t
infiniopDestroyRMSNormDescriptor
(
infiniopRMSNormDescriptor_t
desc
)
{
__C
infiniStatus_t
infiniopDestroyRMSNormDescriptor
(
infiniopRMSNormDescriptor_t
desc
)
{
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::rms_norm::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_CPU
#ifdef ENABLE_CPU_API
case
DevCpu
:
DESTROY
(
INFINI_DEVICE_CPU
,
cpu
)
return
cpuDestroyRMSNormDescriptor
((
RMSNormCpuDescriptor_t
)
desc
);
#endif
#endif
#ifdef ENABLE_NV_GPU
#ifdef ENABLE_NV_GPU
case
DevNvGpu
:
{
case
DevNvGpu
:
{
...
@@ -161,5 +199,8 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
...
@@ -161,5 +199,8 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
}
}
#endif
#endif
}
}
#undef DESTROY
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
}
src/infiniop/ops/rms_norm/rms_norm.h
0 → 100644
View file @
67d81412
#ifndef RMS_NORM_H
#define RMS_NORM_H
#include "../../operator.h"
#include "../../tensor.h"
#include <vector>
struct
RMSNormInfo
{
infiniDtype_t
wtype
;
infiniDtype_t
atype
;
float
epsilon
;
std
::
vector
<
size_t
>
shape
;
std
::
vector
<
ptrdiff_t
>
y_strides
;
std
::
vector
<
ptrdiff_t
>
x_strides
;
size_t
ndim
()
{
return
shape
.
size
();
}
size_t
dim
()
{
return
shape
[
ndim
()
-
1
];
}
};
inline
infiniStatus_t
createRMSNormInfo
(
RMSNormInfo
*
info
,
infiniopTensorDescriptor_t
y_desc
,
infiniopTensorDescriptor_t
x_desc
,
infiniopTensorDescriptor_t
w_desc
,
float
epsilon
)
{
auto
atype
=
y_desc
->
dtype
();
auto
wtype
=
w_desc
->
dtype
();
if
(
x_desc
->
dtype
()
!=
atype
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
if
(
atype
==
INFINI_DTYPE_F16
)
{
if
(
wtype
!=
INFINI_DTYPE_F16
&&
wtype
!=
INFINI_DTYPE_F32
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
else
if
(
atype
==
INFINI_DTYPE_F32
||
atype
==
INFINI_DTYPE_F64
)
{
if
(
atype
!=
wtype
)
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
}
else
{
return
INFINI_STATUS_BAD_TENSOR_DTYPE
;
}
info
->
wtype
=
wtype
;
info
->
atype
=
atype
;
info
->
epsilon
=
epsilon
;
if
(
y_desc
->
ndim
()
!=
2
||
x_desc
->
ndim
()
!=
2
||
w_desc
->
ndim
()
!=
1
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
size_t
batch
=
y_desc
->
shape
()[
0
];
size_t
dim
=
y_desc
->
shape
()[
1
];
if
(
x_desc
->
shape
()[
0
]
!=
batch
||
x_desc
->
shape
()[
1
]
!=
dim
||
w_desc
->
shape
()[
0
]
!=
dim
)
{
return
INFINI_STATUS_BAD_TENSOR_SHAPE
;
}
info
->
shape
=
std
::
move
(
y_desc
->
shape
());
info
->
y_strides
=
std
::
move
(
y_desc
->
strides
());
info
->
x_strides
=
std
::
move
(
x_desc
->
strides
());
return
INFINI_STATUS_SUCCESS
;
}
#define DESCRIPTOR(NAMESPACE) \
namespace op::rms_norm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
RMSNormInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
RMSNormInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) : InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
size_t workspaceSize() const { return _workspace_size; } \
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc, \
infiniopTensorDescriptor_t w_desc, \
float epsilon); \
infiniStatus_t calculate(void *workspace, size_t workspace_size, \
void *y, const void *x, const void *w, void *stream); \
}; \
}
#endif // RMS_NORM_H
src/infiniop/reduce/cpu/reduce.h
0 → 100644
View file @
67d81412
#ifndef __INFINIOP_REDUCE_CPU_H__
#define __INFINIOP_REDUCE_CPU_H__
#include "../../../utils.h"
#include <cstddef>
#ifdef ENABLE_OMP
#include <omp.h>
#endif
#include <type_traits>
namespace
op
::
common_cpu
{
namespace
reduce_op
{
template
<
typename
T
>
using
ReduceToSame
=
std
::
disjunction
<
std
::
is_same
<
T
,
float
>
,
std
::
is_same
<
T
,
double
>
,
std
::
is_same
<
T
,
uint8_t
>
,
std
::
is_same
<
T
,
int8_t
>
,
std
::
is_same
<
T
,
uint16_t
>
,
std
::
is_same
<
T
,
int16_t
>
,
std
::
is_same
<
T
,
uint32_t
>
,
std
::
is_same
<
T
,
int32_t
>
,
std
::
is_same
<
T
,
uint64_t
>
,
std
::
is_same
<
T
,
int64_t
>>
;
template
<
typename
T
,
typename
=
std
::
enable_if_t
<
ReduceToSame
<
T
>
::
value
>>
T
sum
(
const
T
*
data
,
size_t
len
,
ptrdiff_t
stride
=
1
)
{
T
result
=
0
;
for
(
size_t
i
=
0
;
i
<
len
;
i
++
)
{
result
+=
data
[
i
*
stride
];
}
return
result
;
}
float
sum
(
const
fp16_t
*
data
,
size_t
len
,
ptrdiff_t
stride
=
1
)
{
float
result
=
0
;
for
(
size_t
i
=
0
;
i
<
len
;
i
++
)
{
result
+=
utils
::
cast
<
float
>
(
data
[
i
*
stride
]);
}
return
result
;
}
template
<
typename
T
,
typename
=
std
::
enable_if_t
<
ReduceToSame
<
T
>
::
value
>>
T
sumSquared
(
const
T
*
data
,
size_t
len
,
ptrdiff_t
stride
=
1
)
{
T
result
=
0
;
for
(
size_t
i
=
0
;
i
<
len
;
i
++
)
{
T
val
=
data
[
i
*
stride
];
result
+=
val
*
val
;
}
return
result
;
}
float
sumSquared
(
const
fp16_t
*
data
,
size_t
len
,
ptrdiff_t
stride
=
1
)
{
float
result
=
0
;
for
(
size_t
i
=
0
;
i
<
len
;
i
++
)
{
float
val
=
utils
::
cast
<
float
>
(
data
[
i
*
stride
]);
result
+=
val
*
val
;
}
return
result
;
}
}
// namespace reduce_op
}
// namespace op::common_cpu
#endif //__INFINIOP_REDUCE_CPU_H__
src/utils.h
View file @
67d81412
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include "infinicore.h"
#include "infinicore.h"
#include "utils/check.h"
#include "utils/check.h"
#include "utils/custom_types.h"
#include "utils/rearrange.h"
#include "utils/rearrange.h"
inline
size_t
infiniSizeOf
(
infiniDtype_t
dtype
)
{
inline
size_t
infiniSizeOf
(
infiniDtype_t
dtype
)
{
...
...
src/utils/custom_types.cc
0 → 100644
View file @
67d81412
#include "custom_types.h"
#include <cstdint>
#include <cstring>
float
_f16_to_f32
(
fp16_t
val
)
{
uint16_t
h
=
val
.
_v
;
uint32_t
sign
=
(
h
&
0x8000
)
<<
16
;
int32_t
exponent
=
(
h
>>
10
)
&
0x1F
;
uint32_t
mantissa
=
h
&
0x3FF
;
uint32_t
f32
;
if
(
exponent
==
31
)
{
if
(
mantissa
!=
0
)
{
f32
=
sign
|
0x7F800000
|
(
mantissa
<<
13
);
}
else
{
f32
=
sign
|
0x7F800000
;
}
}
else
if
(
exponent
==
0
)
{
if
(
mantissa
==
0
)
{
f32
=
sign
;
}
else
{
exponent
=
-
14
;
while
((
mantissa
&
0x400
)
==
0
)
{
mantissa
<<=
1
;
exponent
--
;
}
mantissa
&=
0x3FF
;
f32
=
sign
|
((
exponent
+
127
)
<<
23
)
|
(
mantissa
<<
13
);
}
}
else
{
f32
=
sign
|
((
exponent
+
127
-
15
)
<<
23
)
|
(
mantissa
<<
13
);
}
float
result
;
memcpy
(
&
result
,
&
f32
,
sizeof
(
result
));
return
result
;
}
fp16_t
_f32_to_f16
(
float
val
)
{
uint32_t
f32
;
memcpy
(
&
f32
,
&
val
,
sizeof
(
f32
));
// Read the bits of the float32
uint16_t
sign
=
(
f32
>>
16
)
&
0x8000
;
// Extract the sign bit
int32_t
exponent
=
((
f32
>>
23
)
&
0xFF
)
-
127
;
// Extract and de-bias the exponent
uint32_t
mantissa
=
f32
&
0x7FFFFF
;
// Extract the mantissa (fraction part)
if
(
exponent
>=
31
)
{
// Special cases for Inf and NaN
// NaN
if
(
exponent
==
128
&&
mantissa
!=
0
)
{
return
fp16_t
{
static_cast
<
uint16_t
>
(
sign
|
0x7E00
)};
}
// Infinity
return
fp16_t
{
static_cast
<
uint16_t
>
(
sign
|
0x7C00
)};
}
else
if
(
exponent
>=
-
14
)
{
// Normalized case
return
fp16_t
{(
uint16_t
)(
sign
|
((
exponent
+
15
)
<<
10
)
|
(
mantissa
>>
13
))};
}
else
if
(
exponent
>=
-
24
)
{
mantissa
|=
0x800000
;
// Add implicit leading 1
mantissa
>>=
(
-
14
-
exponent
);
return
fp16_t
{(
uint16_t
)(
sign
|
(
mantissa
>>
13
))};
}
else
{
// Too small for subnormal: return signed zero
return
fp16_t
{(
uint16_t
)
sign
};
}
}
src/utils/custom_types.h
0 → 100644
View file @
67d81412
#ifndef __INFINIUTILS_CUSTOM_TYPES_H__
#define __INFINIUTILS_CUSTOM_TYPES_H__
#include <stdint.h>
#include <type_traits>
struct
CustomFloat16
{
uint16_t
_v
;
};
typedef
struct
CustomFloat16
fp16_t
;
struct
CustomBFloat16
{
uint16_t
_v
;
};
typedef
struct
CustomBFloat16
bf16_t
;
float
_f16_to_f32
(
fp16_t
val
);
fp16_t
_f32_to_f16
(
float
val
);
namespace
utils
{
// General template for non-fp16_t conversions
template
<
typename
TypeTo
,
typename
TypeFrom
>
TypeTo
cast
(
TypeFrom
val
)
{
if
constexpr
(
std
::
is_same
<
TypeTo
,
TypeFrom
>::
value
)
{
return
val
;
}
else
if
constexpr
(
std
::
is_same
<
TypeTo
,
fp16_t
>::
value
&&
std
::
is_same
<
TypeFrom
,
float
>::
value
)
{
return
_f32_to_f16
(
val
);
}
else
if
constexpr
(
std
::
is_same
<
TypeTo
,
fp16_t
>::
value
&&
!
std
::
is_same
<
TypeFrom
,
float
>::
value
)
{
return
_f32_to_f16
(
static_cast
<
TypeTo
>
(
val
));
}
else
if
constexpr
(
std
::
is_same
<
TypeFrom
,
fp16_t
>::
value
&&
std
::
is_same
<
TypeTo
,
float
>::
value
)
{
return
_f16_to_f32
(
val
);
}
else
if
constexpr
(
std
::
is_same
<
TypeFrom
,
fp16_t
>::
value
&&
!
std
::
is_same
<
TypeTo
,
float
>::
value
)
{
return
static_cast
<
TypeTo
>
(
_f16_to_f32
(
val
));
}
else
{
return
static_cast
<
TypeTo
>
(
val
);
}
}
}
// namespace utils
#endif
src/utils/rearrange.cc
View file @
67d81412
...
@@ -4,6 +4,10 @@
...
@@ -4,6 +4,10 @@
#include <cstring>
#include <cstring>
#include <vector>
#include <vector>
#ifdef ENABLE_OMP
#include <omp.h>
#endif
namespace
utils
{
namespace
utils
{
RearrangeMeta
::
RearrangeMeta
(
std
::
vector
<
ptrdiff_t
>
meta
)
RearrangeMeta
::
RearrangeMeta
(
std
::
vector
<
ptrdiff_t
>
meta
)
...
@@ -98,7 +102,8 @@ void RearrangeMeta::launch(void *dst_, const void *src_) const {
...
@@ -98,7 +102,8 @@ void RearrangeMeta::launch(void *dst_, const void *src_) const {
if
(
count_
==
1
)
{
if
(
count_
==
1
)
{
std
::
memcpy
(
dst_
,
src_
,
unit_
);
std
::
memcpy
(
dst_
,
src_
,
unit_
);
}
else
{
}
else
{
for
(
size_t
i
=
0
;
i
<
count_
;
++
i
)
{
#pragma omp parallel for
for
(
ptrdiff_t
i
=
0
;
i
<
(
ptrdiff_t
)
count_
;
++
i
)
{
auto
dst
=
reinterpret_cast
<
char
*>
(
dst_
);
auto
dst
=
reinterpret_cast
<
char
*>
(
dst_
);
auto
src
=
reinterpret_cast
<
const
char
*>
(
src_
);
auto
src
=
reinterpret_cast
<
const
char
*>
(
src_
);
auto
rem
=
i
;
auto
rem
=
i
;
...
...
test/infiniop/rms_norm.py
View file @
67d81412
...
@@ -25,6 +25,7 @@ from libinfiniop import (
...
@@ -25,6 +25,7 @@ from libinfiniop import (
# These are not meant to be imported from other modules
# These are not meant to be imported from other modules
_TEST_CASES
=
[
_TEST_CASES
=
[
# y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
# y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
((
1
,
4
),
(
1
,
4
),
(
4
,),
None
,
None
,
torch
.
float32
),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
None
,
None
,
torch
.
float32
),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
None
,
None
,
torch
.
float32
),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
None
,
None
,
torch
.
float16
),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
None
,
None
,
torch
.
float16
),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
(
4096
,
1
),
(
4096
,
1
),
torch
.
float32
),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
(
4096
,
1
),
(
4096
,
1
),
torch
.
float32
),
...
@@ -57,7 +58,7 @@ def rms_norm(x, w, eps):
...
@@ -57,7 +58,7 @@ def rms_norm(x, w, eps):
hidden_states
=
x
.
to
(
torch
.
float32
)
hidden_states
=
x
.
to
(
torch
.
float32
)
variance
=
hidden_states
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
variance
=
hidden_states
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
hidden_states
=
hidden_states
*
torch
.
rsqrt
(
variance
+
eps
)
hidden_states
=
hidden_states
*
torch
.
rsqrt
(
variance
+
eps
)
return
w
*
hidden_states
.
to
(
input_dtype
)
return
(
w
*
hidden_states
)
.
to
(
input_dtype
)
def
test
(
def
test
(
...
@@ -79,7 +80,7 @@ def test(
...
@@ -79,7 +80,7 @@ def test(
y
=
torch
.
zeros
(
y_shape
,
dtype
=
dtype
).
to
(
torch_device
)
y
=
torch
.
zeros
(
y_shape
,
dtype
=
dtype
).
to
(
torch_device
)
x
=
torch
.
rand
(
x_shape
,
dtype
=
dtype
).
to
(
torch_device
)
x
=
torch
.
rand
(
x_shape
,
dtype
=
dtype
).
to
(
torch_device
)
w
=
torch
.
ones
(
w_shape
,
dtype
=
w_dtype
).
to
(
torch_device
)
w
=
torch
.
rand
(
w_shape
,
dtype
=
w_dtype
).
to
(
torch_device
)
eps
=
1e-5
eps
=
1e-5
ans
=
rms_norm
(
x
,
w
,
eps
)
ans
=
rms_norm
(
x
,
w
,
eps
)
...
@@ -106,7 +107,7 @@ def test(
...
@@ -106,7 +107,7 @@ def test(
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
for
tensor
in
[
x_tensor
,
y_tensor
,
w_tensor
]:
for
tensor
in
[
x_tensor
,
y_tensor
,
w_tensor
]:
tensor
.
des
criptor
.
contents
.
invalidate
(
)
tensor
.
des
troyDesc
(
lib
)
workspace_size
=
c_uint64
(
0
)
workspace_size
=
c_uint64
(
0
)
check_error
(
check_error
(
...
...
xmake.lua
View file @
67d81412
...
@@ -8,7 +8,6 @@ add_includedirs("include")
...
@@ -8,7 +8,6 @@ add_includedirs("include")
set_encodings
(
"utf-8"
)
set_encodings
(
"utf-8"
)
if
is_mode
(
"debug"
)
then
if
is_mode
(
"debug"
)
then
add_cxflags
(
"-g -O0"
)
add_defines
(
"DEBUG_MODE"
)
add_defines
(
"DEBUG_MODE"
)
end
end
...
@@ -20,7 +19,7 @@ option("cpu")
...
@@ -20,7 +19,7 @@ option("cpu")
option_end
()
option_end
()
option
(
"omp"
)
option
(
"omp"
)
set_default
(
fals
e
)
set_default
(
tru
e
)
set_showmenu
(
true
)
set_showmenu
(
true
)
set_description
(
"Enable or disable OpenMP support for cpu kernel"
)
set_description
(
"Enable or disable OpenMP support for cpu kernel"
)
option_end
()
option_end
()
...
@@ -30,6 +29,10 @@ if has_config("cpu") then
...
@@ -30,6 +29,10 @@ if has_config("cpu") then
add_defines
(
"ENABLE_CPU_API"
)
add_defines
(
"ENABLE_CPU_API"
)
end
end
if
has_config
(
"omp"
)
then
add_defines
(
"ENABLE_OMP"
)
end
-- 英伟达
-- 英伟达
option
(
"nv-gpu"
)
option
(
"nv-gpu"
)
set_default
(
false
)
set_default
(
false
)
...
...
xmake/cpu.lua
View file @
67d81412
...
@@ -5,16 +5,21 @@ target("infiniop-cpu")
...
@@ -5,16 +5,21 @@ target("infiniop-cpu")
set_warnings
(
"all"
,
"error"
)
set_warnings
(
"all"
,
"error"
)
if
not
is_plat
(
"windows"
)
then
if
is_plat
(
"windows"
)
then
add_cxflags
(
"-fPIC"
)
if
has_config
(
"omp"
)
then
add_cxflags
(
"/openmp"
)
end
end
else
set_languages
(
"cxx17"
)
add_cxflags
(
"-fPIC"
)
add_files
(
"../src/infiniop/devices/cpu/*.cc"
,
"../src/infiniop/ops/*/cpu/*.cc"
)
if
has_config
(
"omp"
)
then
if
has_config
(
"omp"
)
then
add_cxflags
(
"-fopenmp"
)
add_cxflags
(
"-fopenmp"
)
add_ldflags
(
"-fopenmp"
)
add_ldflags
(
"-fopenmp"
)
end
end
end
set_languages
(
"cxx17"
)
add_files
(
"../src/infiniop/devices/cpu/*.cc"
,
"../src/infiniop/ops/*/cpu/*.cc"
)
target_end
()
target_end
()
target
(
"infinirt-cpu"
)
target
(
"infinirt-cpu"
)
...
...
xmake/cuda.lua
View file @
67d81412
...
@@ -21,6 +21,7 @@ target("infiniop-cuda")
...
@@ -21,6 +21,7 @@ target("infiniop-cuda")
if
is_plat
(
"windows"
)
then
if
is_plat
(
"windows"
)
then
add_cuflags
(
"-Xcompiler=/utf-8"
,
"--expt-relaxed-constexpr"
,
"--allow-unsupported-compiler"
)
add_cuflags
(
"-Xcompiler=/utf-8"
,
"--expt-relaxed-constexpr"
,
"--allow-unsupported-compiler"
)
add_cuflags
(
"-Xcompiler=/W3"
,
"-Xcompiler=/WX"
)
add_cuflags
(
"-Xcompiler=/W3"
,
"-Xcompiler=/WX"
)
add_cxxflags
(
"/FS"
)
if
CUDNN_ROOT
~=
nil
then
if
CUDNN_ROOT
~=
nil
then
add_linkdirs
(
CUDNN_ROOT
..
"
\\
lib\\x64"
)
add_linkdirs
(
CUDNN_ROOT
..
"
\\
lib\\x64"
)
end
end
...
@@ -46,6 +47,7 @@ target("infinirt-cuda")
...
@@ -46,6 +47,7 @@ target("infinirt-cuda")
if
is_plat
(
"windows"
)
then
if
is_plat
(
"windows"
)
then
add_cuflags
(
"-Xcompiler=/utf-8"
,
"--expt-relaxed-constexpr"
,
"--allow-unsupported-compiler"
)
add_cuflags
(
"-Xcompiler=/utf-8"
,
"--expt-relaxed-constexpr"
,
"--allow-unsupported-compiler"
)
add_cxxflags
(
"/FS"
)
else
else
add_cuflags
(
"-Xcompiler=-fPIC"
)
add_cuflags
(
"-Xcompiler=-fPIC"
)
add_culdflags
(
"-Xcompiler=-fPIC"
)
add_culdflags
(
"-Xcompiler=-fPIC"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment