Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
a5716a8c
Commit
a5716a8c
authored
Apr 22, 2025
by
PanZezhong
Browse files
issue/172 Add 算子 (operator) CPU & CUDA
parent
1622c975
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
444 additions
and
101 deletions
+444
-101
include/infiniop/ops/add.h
include/infiniop/ops/add.h
+4
-0
scripts/python_test.py
scripts/python_test.py
+1
-1
src/infiniop/elementwise/cuda/elementwise_cuda.cuh
src/infiniop/elementwise/cuda/elementwise_cuda.cuh
+5
-5
src/infiniop/ops/add/cpu/add_cpu.cc
src/infiniop/ops/add/cpu/add_cpu.cc
+52
-0
src/infiniop/ops/add/cpu/add_cpu.h
src/infiniop/ops/add/cpu/add_cpu.h
+19
-0
src/infiniop/ops/add/cuda/add_cuda.cu
src/infiniop/ops/add/cuda/add_cuda.cu
+57
-0
src/infiniop/ops/add/cuda/add_cuda.cuh
src/infiniop/ops/add/cuda/add_cuda.cuh
+8
-0
src/infiniop/ops/add/cuda/add_cuda_internal.cuh
src/infiniop/ops/add/cuda/add_cuda_internal.cuh
+26
-0
src/infiniop/ops/add/operator.cc
src/infiniop/ops/add/operator.cc
+118
-0
test/infiniop/add.py
test/infiniop/add.py
+154
-95
No files found.
include/infiniop/ops/add.h
View file @
a5716a8c
...
...
@@ -11,7 +11,11 @@ __C __export infiniStatus_t infiniopCreateAddDescriptor(infiniopHandle_t handle,
infiniopTensorDescriptor_t
a
,
infiniopTensorDescriptor_t
b
);
__C
__export
infiniStatus_t
infiniopGetAddWorkspaceSize
(
infiniopAddDescriptor_t
desc
,
size_t
*
size
);
__C
__export
infiniStatus_t
infiniopAdd
(
infiniopAddDescriptor_t
desc
,
void
*
workspace
,
size_t
workspace_size
,
void
*
c
,
void
const
*
a
,
void
const
*
b
,
...
...
scripts/python_test.py
View file @
a5716a8c
...
...
@@ -12,7 +12,7 @@ os.chdir(PROJECT_DIR)
def
run_tests
(
args
):
failed
=
[]
for
test
in
[
"
causal_softmax
.py"
,
"
add
.py"
,
"gemm.py"
,
"random_sample.py"
,
"rms_norm.py"
,
...
...
src/infiniop/elementwise/cuda/elementwise_cuda.cuh
View file @
a5716a8c
...
...
@@ -208,7 +208,7 @@ struct DeviceImpl::Opaque {
* @param args Additional arguments forwarded to the operation.
* @return infiniStatus_t Returns success or failure status.
*/
template
<
size
_t
BLOCK_SIZE
,
size_t
N
,
typename
Op
,
typename
Tdata
,
typename
...
Args
>
template
<
uint32
_t
BLOCK_SIZE
,
size_t
N
,
typename
Op
,
typename
Tdata
,
typename
...
Args
>
infiniStatus_t
calculateImpl
(
const
op
::
elementwise
::
ElementwiseInfo
&
info
,
void
*
workspace
,
void
*
output
,
...
...
@@ -241,7 +241,7 @@ struct DeviceImpl::Opaque {
* @param args Additional arguments forwarded to the operation.
* @return infiniStatus_t Returns success or failure status.
*/
template
<
size
_t
BLOCK_SIZE
,
size_t
N
,
typename
Op
,
typename
Tout
,
typename
...
Tin
,
typename
...
Args
,
template
<
uint32
_t
BLOCK_SIZE
,
size_t
N
,
typename
Op
,
typename
Tout
,
typename
...
Tin
,
typename
...
Args
,
std
::
enable_if_t
<
(
sizeof
...(
Tin
)
==
Op
::
num_inputs
),
int
>
=
0
>
infiniStatus_t
calculateImpl
(
const
op
::
elementwise
::
ElementwiseInfo
&
info
,
void
*
workspace
,
...
...
@@ -329,7 +329,7 @@ private:
* @param args Additional arguments passed to the kernel.
* @return infiniStatus_t Status code indicating success or failure.
*/
template
<
size
_t
BLOCK_SIZE
,
size_t
N
,
typename
KernelFunc
,
typename
Tout
,
typename
...
Args
>
template
<
uint32
_t
BLOCK_SIZE
,
size_t
N
,
typename
KernelFunc
,
typename
Tout
,
typename
...
Args
>
infiniStatus_t
launchElementwiseKernel
(
const
op
::
elementwise
::
ElementwiseInfo
&
info
,
void
*
workspace
,
...
...
@@ -358,8 +358,8 @@ private:
d_output_shape
,
d_output_strides
,
d_input_shapes
,
d_input_strides
,
stream
));
dim3
blockDims
(
std
::
min
(
BLOCK_SIZE
,
static_cast
<
size
_t
>
(
internal
->
maxThreadsPerBlock
())));
dim3
gridDims
(
std
::
min
(
CEIL_DIV
(
output_size
,
blockDims
.
x
),
static_cast
<
size
_t
>
(
internal
->
gridSizeX
())));
dim3
blockDims
(
std
::
min
(
BLOCK_SIZE
,
static_cast
<
uint32
_t
>
(
internal
->
maxThreadsPerBlock
())));
dim3
gridDims
(
std
::
min
(
uint32_t
(
CEIL_DIV
(
output_size
,
blockDims
.
x
)
)
,
static_cast
<
uint32
_t
>
(
internal
->
gridSizeX
())));
size_t
step
=
gridDims
.
x
*
blockDims
.
x
;
for
(
size_t
i
=
0
;
i
<
output_size
;
i
+=
step
)
{
...
...
src/infiniop/ops/add/cpu/add_cpu.cc
0 → 100644
View file @
a5716a8c
#include "add_cpu.h"

namespace op::add::cpu {

Descriptor::~Descriptor() = default;

/// @brief Create a CPU Add descriptor.
/// Validates dtype (F16/F32/F64) and requires a, b, c to share the same
/// shape, then builds the generic CPU elementwise descriptor.
/// @param handle_        Runtime handle (must be a CPU handle).
/// @param desc_ptr       Out: receives the newly created descriptor.
/// @param out_desc       Descriptor of the output tensor c.
/// @param input_desc_vec Descriptors of the inputs; index 0 is a, index 1 is b.
/// @return INFINI_STATUS_SUCCESS or a validation error status.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &a_desc = input_desc_vec.at(0);
    const auto &b_desc = input_desc_vec.at(1);
    const auto &c_shape = out_desc->shape();
    const auto &a_shape = a_desc->shape();
    const auto &b_shape = b_desc->shape();

    // Only floating-point element types are supported by this operator.
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

    // create CPU elementwise descriptor
    CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);

    return INFINI_STATUS_SUCCESS;
}

/// @brief Run c = a + b on the CPU, dispatching on the stored dtype.
/// @param workspace       Unused on the CPU path (kept for API symmetry).
/// @param workspace_size  Unused on the CPU path.
/// @param output          Pointer to the output buffer c.
/// @param inputs          Input buffers; index 0 is a, index 1 is b.
/// @param stream          Unused on the CPU path.
/// @return Status of the underlying elementwise computation, or
///         INFINI_STATUS_BAD_TENSOR_DTYPE for unsupported dtypes.
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    // Every case returns, so no code is needed after the switch.
    // (The original had an unreachable `return INFINI_STATUS_SUCCESS;` here.)
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<AddOp, fp16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<AddOp, float>(_info, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<AddOp, double>(_info, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
} // namespace op::add::cpu
src/infiniop/ops/add/cpu/add_cpu.h
0 → 100644
View file @
a5716a8c
#ifndef __ADD_CPU_H__
#define __ADD_CPU_H__

#include "../../../elementwise/cpu/elementwise_cpu.h"

// Declares the op::add::cpu::Descriptor boilerplate via the shared
// elementwise macro (presumably the Descriptor class definition — confirm
// against elementwise_cpu.h).
ELEMENTWISE_DESCRIPTOR(add, cpu)

namespace op::add::cpu {

// Elementwise addition functor consumed by the CPU elementwise framework.
// Was `typedef struct AddOp { public: ... } AddOp;` — the typedef and the
// `public:` are C-isms that are redundant in C++; a plain struct declares
// the same type name.
struct AddOp {
    // Number of input tensors this operation consumes (a and b).
    static constexpr size_t num_inputs = 2;

    // Returns a + b for any element type supporting operator+.
    template <typename T>
    T operator()(const T &a, const T &b) const {
        return a + b;
    }
};

} // namespace op::add::cpu

#endif // __ADD_CPU_H__
src/infiniop/ops/add/cuda/add_cuda.cu
0 → 100644
View file @
a5716a8c
#include "add_cuda.cuh"
#include "add_cuda_internal.cuh"

namespace op::add::cuda {

Descriptor::~Descriptor() = default;

/// @brief Create a CUDA Add descriptor.
/// Validates dtype (F16/F32/F64) and requires a, b, c to share the same
/// shape, then builds the generic CUDA elementwise descriptor.
/// @param handle_        Runtime handle (must be a CUDA handle).
/// @param desc_ptr       Out: receives the newly created descriptor.
/// @param out_desc       Descriptor of the output tensor c.
/// @param input_desc_vec Descriptors of the inputs; index 0 is a, index 1 is b.
/// @return INFINI_STATUS_SUCCESS or a validation error status.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &a_desc = input_desc_vec.at(0);
    const auto &b_desc = input_desc_vec.at(1);
    const auto &c_shape = out_desc->shape();
    const auto &a_shape = a_desc->shape();
    const auto &b_shape = b_desc->shape();

    // Only floating-point element types are supported by this operator.
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

    // create CUDA elementwise descriptor
    // (terminating `;` added for consistency with the CPU counterpart)
    CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);

    return INFINI_STATUS_SUCCESS;
}

/// @brief Run c = a + b on the GPU, dispatching on the stored dtype.
/// @param workspace       Device scratch buffer for the elementwise launcher.
/// @param workspace_size  Size of @p workspace; must be >= _workspace_size.
/// @param output          Pointer to the output buffer c.
/// @param inputs          Input buffers; index 0 is a, index 1 is b.
/// @param stream          CUDA stream to launch on.
/// @return Status of the kernel launch, INFINI_STATUS_INSUFFICIENT_WORKSPACE,
///         or INFINI_STATUS_BAD_TENSOR_DTYPE for unsupported dtypes.
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // 256 is the block size used for every dtype. Every case returns, so
    // no code is needed after the switch. (The original had an unreachable
    // `return INFINI_STATUS_SUCCESS;` here.)
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, AddOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, AddOp, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, AddOp, double>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
}
} // namespace op::add::cuda
src/infiniop/ops/add/cuda/add_cuda.cuh
0 → 100644
View file @
a5716a8c
#ifndef __ADD_CUDA_API_H__
#define __ADD_CUDA_API_H__

#include "../../../elementwise/cuda/elementwise_cuda_api.cuh"

// Public CUDA-side API header for the Add operator: expands the shared
// elementwise macro for (add, cuda) — presumably declaring
// op::add::cuda::Descriptor; confirm against elementwise_cuda_api.cuh.
ELEMENTWISE_DESCRIPTOR(add, cuda)

#endif // __ADD_CUDA_API_H__
src/infiniop/ops/add/cuda/add_cuda_internal.cuh
0 → 100644
View file @
a5716a8c
#ifndef __ADD_CUDA_H__
#define __ADD_CUDA_H__

#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_fp16.h>

namespace op::add::cuda {

// Device-side elementwise addition functor. Dispatches at compile time to
// the fastest intrinsic for each element type.
// Was `typedef struct AddOp { public: ... } AddOp;` — reduced to a plain
// C++ struct (typedef and `public:` were redundant).
struct AddOp {
    // Number of input tensors this operation consumes (a and b).
    static constexpr size_t num_inputs = 2;

    template <typename T>
    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
        if constexpr (std::is_same_v<T, half2>) {
            // Packed half addition: two fp16 sums per instruction.
            return __hadd2(a, b);
        } else if constexpr (std::is_same_v<T, half>) {
            return __hadd(a, b);
        } else if constexpr (std::is_same_v<T, float>) {
            // FIX: was __fadd_rd (round toward -infinity), which silently
            // biased float results downward and disagreed with both the CPU
            // backend's plain `a + b` and the generic path below (round to
            // nearest even). __fadd_rn is the round-to-nearest intrinsic.
            return __fadd_rn(a, b);
        } else {
            // double and any other arithmetic type: default IEEE addition.
            return a + b;
        }
    }
};

} // namespace op::add::cuda

#endif // __ADD_CUDA_H__
src/infiniop/ops/add/operator.cc
0 → 100644
View file @
a5716a8c
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/add.h"
#ifdef ENABLE_CPU_API
#include "cpu/add_cpu.h"
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/add_cuda.cuh"
#endif
// Creates an Add descriptor for the device selected by `handle->device`.
// Dispatch is preprocessor-conditional: each enabled backend contributes one
// `case` via the local CREATE macro; unsupported devices fall through to the
// default and report INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED.
// `c_desc` describes the output; `a_desc`/`b_desc` the two inputs.
__C infiniStatus_t infiniopCreateAddDescriptor(
    infiniopHandle_t handle,
    infiniopAddDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t c_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc) {

// Expands to a `case` that forwards to the backend's Descriptor::create,
// casting the opaque out-pointer to the backend-specific Descriptor type.
#define CREATE(CASE, NAMESPACE)                                            \
    case CASE:                                                             \
        return op::add::NAMESPACE::Descriptor::create(                     \
            handle,                                                        \
            reinterpret_cast<op::add::NAMESPACE::Descriptor **>(desc_ptr), \
            c_desc,                                                        \
            {a_desc,                                                       \
             b_desc})

    switch (handle->device) {
#ifdef ENABLE_CPU_API
        CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
        CREATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}
// Reports the workspace size (in bytes) required by infiniopAdd for the
// given descriptor, writing it to *size. Returns
// INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED for devices not compiled in.
__C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size) {

// Expands to a `case` that queries the backend descriptor's workspaceSize().
#define GET(CASE, NAMESPACE)                                                               \
    case CASE:                                                                             \
        *size = reinterpret_cast<op::add::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS;

    // Every path out of the switch returns; the original function had an
    // unreachable duplicate `return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;`
    // after the switch, which has been removed.
    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        GET(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_CUDA_API
        GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET
}
// Executes c = a + b using a previously created Add descriptor.
// `workspace`/`workspace_size` provide scratch memory (the CUDA backend
// validates the size; see Descriptor::calculate). `stream` is the backend
// execution stream (ignored by the CPU path). Dispatch mirrors
// infiniopCreateAddDescriptor: one `case` per compiled-in backend.
__C infiniStatus_t infiniopAdd(
    infiniopAddDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *c,
    const void *a,
    const void *b,
    void *stream) {

// Expands to a `case` that forwards to the backend's calculate().
#define CALCULATE(CASE, NAMESPACE)                                            \
    case CASE:                                                                \
        return reinterpret_cast<const op::add::NAMESPACE::Descriptor *>(desc) \
            ->calculate(workspace, workspace_size, c, {a, b}, stream)

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}
// Destroys an Add descriptor created by infiniopCreateAddDescriptor,
// deleting it as the backend-specific Descriptor type so the correct
// destructor runs. Unsupported device types yield
// INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED.
__C infiniStatus_t infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) {

// Expands to a `case` that deletes the descriptor via its concrete type.
#define DELETE(CASE, NAMESPACE)                                             \
    case CASE:                                                              \
        delete reinterpret_cast<const op::add::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_CPU_API
        DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
        DELETE(INFINI_DEVICE_NVIDIA, cuda);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DELETE
}
test/infiniop/add.py
View file @
a5716a8c
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_void_p
import
torch
import
ctypes
import
sys
import
os
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
".."
,
".."
)))
from
operatorspy
import
(
open_lib
,
to_tensor
,
DeviceEnum
,
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_void_p
,
c_uint64
from
libinfiniop
import
(
infiniopHandle_t
,
infiniopTensorDescriptor_t
,
create_handle
,
destroy_handle
,
open_lib
,
to_tensor
,
get_test_devices
,
check_error
,
rearrange_if_needed
,
test_operator
,
get_args
,
debug
,
get_tolerance
,
profile_operation
,
create_workspace
,
)
from
operatorspy.tests.test_utils
import
get_args
from
enum
import
Enum
,
auto
import
torch
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_
=
[
# shape, a_stride, b_stride, c_stride
((
13
,
4
),
None
,
None
,
None
),
((
13
,
4
),
(
10
,
1
),
(
10
,
1
),
(
10
,
1
)),
((
13
,
4
),
(
0
,
1
),
None
,
None
),
((
13
,
4
,
4
),
None
,
None
,
None
),
((
13
,
4
,
4
),
(
20
,
4
,
1
),
(
20
,
4
,
1
),
(
20
,
4
,
1
)),
((
13
,
4
,
4
),
(
4
,
0
,
1
),
(
0
,
4
,
1
),
None
),
((
16
,
5632
),
None
,
None
,
None
),
((
16
,
5632
),
(
13312
,
1
),
(
13312
,
1
),
(
13312
,
1
)),
((
4
,
4
,
5632
),
None
,
None
,
None
),
((
4
,
4
,
5632
),
(
45056
,
5632
,
1
),
(
45056
,
5632
,
1
),
(
45056
,
5632
,
1
)),
]
class
Inplace
(
Enum
):
...
...
@@ -26,6 +43,34 @@ class Inplace(Enum):
INPLACE_B
=
auto
()
# Inplace options applied for each test case in _TEST_CASES_
_INPLACE
=
[
Inplace
.
OUT_OF_PLACE
,
Inplace
.
INPLACE_A
,
Inplace
.
INPLACE_B
,
]
# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_
_TEST_CASES
=
[
test_case
+
(
inplace_item
,)
for
test_case
in
_TEST_CASES_
for
inplace_item
in
_INPLACE
]
# Data types used for testing
_TENSOR_DTYPES
=
[
torch
.
float16
,
torch
.
float32
]
# Tolerance map for different data types
_TOLERANCE_MAP
=
{
torch
.
float16
:
{
"atol"
:
1e-4
,
"rtol"
:
1e-2
},
}
DEBUG
=
False
PROFILE
=
False
NUM_PRERUN
=
10
NUM_ITERATIONS
=
1000
class
AddDescriptor
(
Structure
):
_fields_
=
[(
"device"
,
c_int32
)]
...
...
@@ -37,42 +82,71 @@ def add(x, y):
return
torch
.
add
(
x
,
y
)
def
process_tensors
(
c
,
c_strides
,
a
,
a_stride
,
b
,
b_stride
,
inplace
):
"""
rearrange the tensors if needed and apply the inplace config.
if inplace is true and the output (i.e., c) is placed to the broadcasted input,
the inplace config is ignored and out-of-place is used
"""
original_c_strides
=
c_strides
if
c_strides
else
c
.
stride
()
def
_rearrange
(
tensor
,
strides
):
if
strides
and
0
in
strides
:
tensor
.
set_
(
tensor
.
untyped_storage
(),
0
,
tensor
.
shape
,
strides
)
return
tensor
else
:
return
rearrange_if_needed
(
tensor
,
strides
)
a
,
b
,
c
=
[
_rearrange
(
tensor
,
stride
)
for
tensor
,
stride
in
zip
([
a
,
b
,
c
],
[
a_stride
,
b_stride
,
c_strides
])
]
c
=
(
c
if
inplace
==
Inplace
.
OUT_OF_PLACE
else
(
a
if
inplace
==
Inplace
.
INPLACE_A
else
b
)
)
# if inplace is true and c has broadcasted config, reset it to the original unbroadcasted strides
if
0
in
c
.
stride
():
c
.
set_
(
c
.
untyped_storage
(),
0
,
c
.
shape
,
original_c_strides
)
return
a
,
b
,
c
def
test
(
lib
,
handle
,
torch_device
,
c_
shape
,
a_s
hap
e
,
b_s
hap
e
,
tensor_dtype
=
torch
.
float16
,
shape
,
a_s
tride
=
Non
e
,
b_s
tride
=
Non
e
,
c_stride
=
None
,
inplace
=
Inplace
.
OUT_OF_PLACE
,
dtype
=
torch
.
float16
,
sync
=
None
,
):
print
(
f
"Testing Add on
{
torch_device
}
with c_shape:
{
c_shape
}
a_shape:
{
a_shape
}
b_shape:
{
b_shape
}
dtype:
{
tensor_dtype
}
inplace:
{
inplace
.
name
}
"
f
"Testing Add on
{
torch_device
}
with shape:
{
shape
}
a_stride:
{
a_stride
}
b_stride:
{
b_stride
}
c_stride:
{
c_stride
}
"
f
"dtype:
{
dtype
}
inplace:
{
inplace
}
"
)
if
a_shape
!=
b_shape
and
inplace
!=
Inplace
.
OUT_OF_PLACE
:
print
(
"Unsupported test: broadcasting does not support in-place"
)
return
a
=
torch
.
rand
(
a_shape
,
dtype
=
tensor_dtype
).
to
(
torch_device
)
b
=
torch
.
rand
(
b_shape
,
dtype
=
tensor_dtype
).
to
(
torch_device
)
c
=
(
torch
.
rand
(
c_shape
,
dtype
=
tensor_dtype
).
to
(
torch_device
)
if
inplace
==
Inplace
.
OUT_OF_PLACE
else
(
a
if
inplace
==
Inplace
.
INPLACE_A
else
b
)
)
a
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
b
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
c
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
a
,
b
,
c
=
process_tensors
(
c
,
c_stride
,
a
,
a_stride
,
b
,
b_stride
,
inplace
)
ans
=
add
(
a
,
b
)
a_tensor
=
to_tensor
(
a
,
lib
)
b_tensor
=
to_tensor
(
b
,
lib
)
a_tensor
,
b_tensor
=
[
to_tensor
(
tensor
,
lib
)
for
tensor
in
[
a
,
b
]]
c_tensor
=
(
to_tensor
(
c
,
lib
)
if
inplace
==
Inplace
.
OUT_OF_PLACE
else
(
a_tensor
if
inplace
==
Inplace
.
INPLACE_A
else
b_tensor
)
)
descriptor
=
infiniopAddDescriptor_t
()
if
sync
is
not
None
:
sync
()
descriptor
=
infiniopAddDescriptor_t
()
check_error
(
lib
.
infiniopCreateAddDescriptor
(
handle
,
...
...
@@ -84,74 +158,48 @@ def test(
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
c_tensor
.
descriptor
.
contents
.
invalidate
()
a_tensor
.
descriptor
.
contents
.
invalidate
()
b_tensor
.
descriptor
.
contents
.
invalidate
()
for
tensor
in
[
a_tensor
,
b_tensor
,
c_tensor
]:
tensor
.
destroyDesc
(
lib
)
workspace_size
=
c_uint64
(
0
)
check_error
(
lib
.
infiniop
Add
(
descriptor
,
c_tensor
.
data
,
a_tensor
.
data
,
b_tensor
.
data
,
Non
e
)
lib
.
infiniop
GetAddWorkspaceSize
(
descriptor
,
ctypes
.
byref
(
workspace_siz
e
)
)
)
assert
torch
.
allclose
(
c
,
ans
,
atol
=
0
,
rtol
=
1e-3
)
check_error
(
lib
.
infiniopDestroyAddDescriptor
(
descriptor
))
def
test_cpu
(
lib
,
test_cases
):
device
=
DeviceEnum
.
DEVICE_CPU
handle
=
create_handle
(
lib
,
device
)
for
c_shape
,
a_shape
,
b_shape
,
inplace
in
test_cases
:
# fmt: off
test
(
lib
,
handle
,
"cpu"
,
c_shape
,
a_shape
,
b_shape
,
tensor_dtype
=
torch
.
float16
,
inplace
=
inplace
)
test
(
lib
,
handle
,
"cpu"
,
c_shape
,
a_shape
,
b_shape
,
tensor_dtype
=
torch
.
float32
,
inplace
=
inplace
)
# fmt: on
destroy_handle
(
lib
,
handle
)
workspace
=
create_workspace
(
workspace_size
.
value
,
c
.
device
)
def
lib_add
():
check_error
(
lib
.
infiniopAdd
(
descriptor
,
workspace
.
data_ptr
()
if
workspace
is
not
None
else
None
,
workspace_size
.
value
,
c_tensor
.
data
,
a_tensor
.
data
,
b_tensor
.
data
,
None
,
)
)
def
test_cuda
(
lib
,
test_cases
):
device
=
DeviceEnum
.
DEVICE_CUDA
handle
=
create_handle
(
lib
,
device
)
for
c_shape
,
a_shape
,
b_shape
,
inplace
in
test_cases
:
# fmt: off
test
(
lib
,
handle
,
"cuda"
,
c_shape
,
a_shape
,
b_shape
,
tensor_dtype
=
torch
.
float16
,
inplace
=
inplace
)
test
(
lib
,
handle
,
"cuda"
,
c_shape
,
a_shape
,
b_shape
,
tensor_dtype
=
torch
.
float32
,
inplace
=
inplace
)
# fmt: on
destroy_handle
(
lib
,
handle
)
lib_add
()
def
test_bang
(
lib
,
test_cases
):
import
torch_mlu
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
debug
(
c
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
c
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
device
=
DeviceEnum
.
DEVICE_BANG
handle
=
create_handle
(
lib
,
device
)
for
c_shape
,
a_shape
,
b_shape
,
inplace
in
test_cases
:
# Profiling workflow
if
PROFILE
:
# fmt: off
test
(
lib
,
handle
,
"mlu"
,
c_shape
,
a_shape
,
b_shape
,
tensor_dtype
=
torch
.
float16
,
inplace
=
inplace
)
test
(
lib
,
handle
,
"mlu"
,
c_shape
,
a_shape
,
b_shape
,
tensor_dtype
=
torch
.
float32
,
inplace
=
inplace
)
profile_operation
(
"PyTorch"
,
lambda
:
add
(
a
,
b
),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
profile_operation
(
" lib"
,
lambda
:
lib_add
(),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
# fmt: on
destroy_handle
(
lib
,
handle
)
check_error
(
lib
.
infiniopDestroyAddDescriptor
(
descriptor
)
)
if
__name__
==
"__main__"
:
test_cases
=
[
# fmt: off
# c_shape, a_shape, b_shape, inplace
# ((32, 150, 512000), (32, 150, 512000), (32, 150, 512000), Inplace.OUT_OF_PLACE),
# ((32, 150, 51200), (32, 150, 51200), (32, 150, 1), Inplace.OUT_OF_PLACE),
# ((32, 150, 51200), (32, 150, 51200), (32, 150, 51200), Inplace.OUT_OF_PLACE),
((
1
,
3
),
(
1
,
3
),
(
1
,
3
),
Inplace
.
OUT_OF_PLACE
),
((),
(),
(),
Inplace
.
OUT_OF_PLACE
),
((
3
,
3
),
(
3
,
3
),
(
3
,
3
),
Inplace
.
OUT_OF_PLACE
),
((
2
,
20
,
3
),
(
2
,
1
,
3
),
(
2
,
20
,
3
),
Inplace
.
OUT_OF_PLACE
),
((
32
,
20
,
512
),
(
32
,
20
,
512
),
(
32
,
20
,
512
),
Inplace
.
INPLACE_A
),
((
32
,
20
,
512
),
(
32
,
20
,
512
),
(
32
,
20
,
512
),
Inplace
.
INPLACE_B
),
((
32
,
256
,
112
,
112
),
(
32
,
256
,
112
,
1
),
(
32
,
256
,
112
,
112
),
Inplace
.
OUT_OF_PLACE
),
((
32
,
256
,
112
,
112
),
(
32
,
256
,
112
,
112
),
(
32
,
256
,
112
,
112
),
Inplace
.
OUT_OF_PLACE
),
((
2
,
4
,
3
),
(
2
,
1
,
3
),
(
4
,
3
),
Inplace
.
OUT_OF_PLACE
),
((
2
,
3
,
4
,
5
),
(
2
,
3
,
4
,
5
),
(
5
,),
Inplace
.
OUT_OF_PLACE
),
((
3
,
2
,
4
,
5
),
(
4
,
5
),
(
3
,
2
,
1
,
1
),
Inplace
.
OUT_OF_PLACE
),
# fmt: on
]
args
=
get_args
()
lib
=
open_lib
()
lib
.
infiniopCreateAddDescriptor
.
restype
=
c_int32
lib
.
infiniopCreateAddDescriptor
.
argtypes
=
[
infiniopHandle_t
,
...
...
@@ -160,25 +208,36 @@ if __name__ == "__main__":
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
]
lib
.
infiniopGetAddWorkspaceSize
.
restype
=
c_int32
lib
.
infiniopGetAddWorkspaceSize
.
argtypes
=
[
infiniopAddDescriptor_t
,
POINTER
(
c_uint64
),
]
lib
.
infiniopAdd
.
restype
=
c_int32
lib
.
infiniopAdd
.
argtypes
=
[
infiniopAddDescriptor_t
,
c_void_p
,
c_uint64
,
c_void_p
,
c_void_p
,
c_void_p
,
c_void_p
,
]
lib
.
infiniopDestroyAddDescriptor
.
restype
=
c_int32
lib
.
infiniopDestroyAddDescriptor
.
argtypes
=
[
infiniopAddDescriptor_t
,
]
if
args
.
cpu
:
test_cpu
(
lib
,
test_cases
)
if
args
.
cuda
:
test_cuda
(
lib
,
test_cases
)
if
args
.
bang
:
test_bang
(
lib
,
test_cases
)
if
not
(
args
.
cpu
or
args
.
cuda
or
args
.
bang
):
test_cpu
(
lib
,
test_cases
)
# Configure testing options
DEBUG
=
args
.
debug
PROFILE
=
args
.
profile
NUM_PRERUN
=
args
.
num_prerun
NUM_ITERATIONS
=
args
.
num_iterations
for
device
in
get_test_devices
(
args
):
test_operator
(
lib
,
device
,
test
,
_TEST_CASES
,
_TENSOR_DTYPES
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment