Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
ca2f34cf
Commit
ca2f34cf
authored
Feb 20, 2025
by
xgqdut2016
Browse files
issue/66: modified test py
parent
87d10975
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
544 additions
and
537 deletions
+544
-537
test/infiniop/causal_softmax.py
test/infiniop/causal_softmax.py
+96
-83
test/infiniop/random_sample.py
test/infiniop/random_sample.py
+75
-87
test/infiniop/rearrange.py
test/infiniop/rearrange.py
+87
-85
test/infiniop/rms_norm.py
test/infiniop/rms_norm.py
+98
-103
test/infiniop/rotary_embedding.py
test/infiniop/rotary_embedding.py
+50
-30
test/infiniop/swiglu.py
test/infiniop/swiglu.py
+138
-149
No files found.
test/infiniop/causal_softmax.py
View file @
ca2f34cf
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_uint64
,
c_void_p
import
torch
import
ctypes
import
ctypes
import
sys
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_size_t
,
c_uint64
,
c_void_p
,
c_float
import
os
from
libinfiniop
import
(
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
".."
,
".."
)))
from
operatorspy
import
(
open_lib
,
to_tensor
,
DeviceEnum
,
infiniopHandle_t
,
infiniopHandle_t
,
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
create_handle
,
open_lib
,
destroy_handle
,
to_tensor
,
get_test_devices
,
check_error
,
check_error
,
rearrange_
tensor
,
rearrange_
if_needed
,
create_workspace
,
create_workspace
,
test_operator
,
get_args
,
debug
,
get_tolerance
,
profile_operation
,
)
)
from
operatorspy.tests.test_utils
import
get_args
# ==============================================================================
import
torch
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES
=
[
# x_shape, x_stride
((
32
,
512
),
None
),
((
32
,
512
),
(
1024
,
1
)),
((
32
,
5
,
5
),
None
),
((
32
,
20
,
512
),
None
),
((
32
,
20
,
512
),
(
20480
,
512
,
1
)),
# Ascend 暂不支持非连续
((
32
,
20
,
4
,
512
),
None
),
((
32
,
20
,
4
,
512
),
(
81920
,
2048
,
512
,
1
)),
]
# Data types used for testing
_TENSOR_DTYPES
=
[
torch
.
float16
,
torch
.
float32
]
# Tolerance map for different data types
_TOLERANCE_MAP
=
{
torch
.
float16
:
{
'atol'
:
0
,
'rtol'
:
1e-2
},
torch
.
float32
:
{
'atol'
:
0
,
'rtol'
:
1e-3
},
}
DEBUG
=
False
PROFILE
=
False
NUM_PRERUN
=
10
NUM_ITERATIONS
=
1000
class
CausalSoftmaxDescriptor
(
Structure
):
class
CausalSoftmaxDescriptor
(
Structure
):
_fields_
=
[(
"device"
,
c_int32
)]
_fields_
=
[(
"device"
,
c_int32
)]
...
@@ -37,88 +61,78 @@ def causal_softmax(x):
...
@@ -37,88 +61,78 @@ def causal_softmax(x):
return
torch
.
nn
.
functional
.
softmax
(
masked
,
dim
=-
1
).
to
(
type
)
return
torch
.
nn
.
functional
.
softmax
(
masked
,
dim
=-
1
).
to
(
type
)
def
test
(
lib
,
handle
,
torch_device
,
x_shape
,
x_stride
=
None
,
x_dtype
=
torch
.
float16
):
def
test
(
lib
,
handle
,
torch_device
,
x_shape
,
x_stride
=
None
,
dtype
=
torch
.
float16
):
print
(
print
(
f
"Testing CausalSoftmax on
{
torch_device
}
with x_shape:
{
x_shape
}
x_stride:
{
x_stride
}
dtype:
{
x_
dtype
}
"
f
"Testing CausalSoftmax on
{
torch_device
}
with x_shape:
{
x_shape
}
x_stride:
{
x_stride
}
dtype:
{
dtype
}
"
)
)
x
=
torch
.
rand
(
x_shape
,
dtype
=
x_dtype
).
to
(
torch_device
)
x
=
torch
.
rand
(
x_shape
,
dtype
=
dtype
).
to
(
torch_device
)
if
x_stride
is
not
None
:
x
=
rearrange_tensor
(
x
,
x_stride
)
ans
=
causal_softmax
(
x
)
ans
=
causal_softmax
(
x
)
x
=
rearrange_if_needed
(
x
,
x_stride
)
x_tensor
=
to_tensor
(
x
,
lib
)
x_tensor
=
to_tensor
(
x
,
lib
)
descriptor
=
infiniopCausalSoftmaxDescriptor_t
()
descriptor
=
infiniopCausalSoftmaxDescriptor_t
()
check_error
(
check_error
(
lib
.
infiniopCreateCausalSoftmaxDescriptor
(
lib
.
infiniopCreateCausalSoftmaxDescriptor
(
handle
,
ctypes
.
byref
(
descriptor
),
x_tensor
.
descriptor
handle
,
ctypes
.
byref
(
descriptor
),
x_tensor
.
descriptor
)
)
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor
.
descriptor
.
contents
.
invalidate
()
workspace_size
=
c_uint64
(
0
)
workspace_size
=
c_uint64
(
0
)
check_error
(
check_error
(
lib
.
infiniopGetCausalSoftmaxWorkspaceSize
(
lib
.
infiniopGetCausalSoftmaxWorkspaceSize
(
descriptor
,
ctypes
.
byref
(
workspace_size
)
descriptor
,
ctypes
.
byref
(
workspace_size
)
)
)
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor
.
descriptor
.
contents
.
invalidate
()
workspace
=
create_workspace
(
workspace_size
.
value
,
x
.
device
)
workspace
=
create_workspace
(
workspace_size
.
value
,
x
.
device
)
check_error
(
def
lib_causal_softmax
():
lib
.
infiniopCausalSoftmax
(
check_error
(
descriptor
,
lib
.
infiniopCausalSoftmax
(
workspace
.
data_ptr
()
if
workspace
is
not
None
else
None
,
descriptor
,
workspace_size
.
value
,
workspace
.
data_ptr
()
if
workspace
is
not
None
else
None
,
x_tensor
.
data
,
workspace_size
.
value
,
None
,
x_tensor
.
data
,
None
,
)
)
)
)
lib_causal_softmax
()
assert
torch
.
allclose
(
x
,
ans
,
atol
=
0
,
rtol
=
1e-2
)
check_error
(
lib
.
infiniopDestroyCausalSoftmaxDescriptor
(
descriptor
))
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
debug
(
x
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
def
test_cpu
(
lib
,
test_cases
):
assert
torch
.
allclose
(
x
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
device
=
DeviceEnum
.
DEVICE_CPU
handle
=
create_handle
(
lib
,
device
)
# Profiling workflow
for
x_shape
,
x_stride
in
test_cases
:
if
PROFILE
:
test
(
lib
,
handle
,
"cpu"
,
x_shape
,
x_stride
)
# fmt: off
destroy_handle
(
lib
,
handle
)
profile_operation
(
"PyTorch"
,
lambda
:
causal_softmax
(
x
),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
profile_operation
(
" lib"
,
lambda
:
lib_causal_softmax
(),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
# fmt: on
def
test_cuda
(
lib
,
test_cases
):
device
=
DeviceEnum
.
DEVICE_CUDA
handle
=
create_handle
(
lib
,
device
)
for
x_shape
,
x_stride
in
test_cases
:
test
(
lib
,
handle
,
"cuda"
,
x_shape
,
x_stride
)
destroy_handle
(
lib
,
handle
)
def
test_bang
(
lib
,
test_cases
):
import
torch_mlu
device
=
DeviceEnum
.
DEVICE_BANG
handle
=
create_handle
(
lib
,
device
)
for
x_shape
,
x_stride
in
test_cases
:
test
(
lib
,
handle
,
"mlu"
,
x_shape
,
x_stride
)
destroy_handle
(
lib
,
handle
)
check_error
(
lib
.
infiniopDestroyCausalSoftmaxDescriptor
(
descriptor
))
def
test_ascend
(
lib
,
test_cases
):
import
torch_npu
device
=
DeviceEnum
.
DEVICE_ASCEND
handle
=
create_handle
(
lib
,
device
)
for
x_shape
,
x_stride
in
test_cases
:
test
(
lib
,
handle
,
"npu"
,
x_shape
,
x_stride
)
destroy_handle
(
lib
,
handle
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_cases
=
[
# x_shape, x_stride
((
32
,
20
,
512
),
None
),
((
32
,
20
,
512
),
(
20480
,
512
,
1
)),
# Ascend 暂不支持非连续
]
args
=
get_args
()
args
=
get_args
()
lib
=
open_lib
()
lib
=
open_lib
()
lib
.
infiniopCreateCausalSoftmaxDescriptor
.
restype
=
c_int32
lib
.
infiniopCreateCausalSoftmaxDescriptor
.
restype
=
c_int32
...
@@ -144,15 +158,14 @@ if __name__ == "__main__":
...
@@ -144,15 +158,14 @@ if __name__ == "__main__":
lib
.
infiniopDestroyCausalSoftmaxDescriptor
.
argtypes
=
[
lib
.
infiniopDestroyCausalSoftmaxDescriptor
.
argtypes
=
[
infiniopCausalSoftmaxDescriptor_t
,
infiniopCausalSoftmaxDescriptor_t
,
]
]
# Configure testing options
DEBUG
=
args
.
debug
PROFILE
=
args
.
profile
NUM_PRERUN
=
args
.
num_prerun
NUM_ITERATIONS
=
args
.
num_iterations
for
device
in
get_test_devices
(
args
):
test_operator
(
lib
,
device
,
test
,
_TEST_CASES
,
_TENSOR_DTYPES
)
if
args
.
cpu
:
test_cpu
(
lib
,
test_cases
)
if
args
.
cuda
:
test_cuda
(
lib
,
test_cases
)
if
args
.
bang
:
test_bang
(
lib
,
test_cases
)
if
args
.
ascend
:
test_ascend
(
lib
,
test_cases
)
if
not
(
args
.
cpu
or
args
.
cuda
or
args
.
bang
or
args
.
ascend
):
test_cpu
(
lib
,
test_cases
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
test/infiniop/random_sample.py
View file @
ca2f34cf
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_uint64
,
c_void_p
,
c_float
import
torch
import
ctypes
import
ctypes
import
sys
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_size_t
,
c_uint64
,
c_void_p
,
c_float
import
os
from
libinfiniop
import
(
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
".."
,
".."
)))
from
operatorspy
import
(
open_lib
,
to_tensor
,
DeviceEnum
,
infiniopHandle_t
,
infiniopHandle_t
,
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
create_handle
,
open_lib
,
destroy_handle
,
to_tensor
,
get_test_devices
,
check_error
,
check_error
,
rearrange_
tensor
,
rearrange_
if_needed
,
create_workspace
,
create_workspace
,
U64
,
test_operator
,
get_args
,
debug
,
get_tolerance
,
profile_operation
,
)
)
from
operatorspy.tests.test_utils
import
get_args
# ==============================================================================
import
torch
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES
=
[
# voc, random_val, topp, topk, temperature
(
512
,
0.8
,
0.8
,
3
,
0.5
),
(
4096
,
0.05
,
0.9
,
5
,
1.0
),
(
16384
,
0.15
,
0.85
,
10
,
2.0
),
(
512
,
0.08
,
0
,
3
,
0.5
),
(
4096
,
0.5
,
0.9
,
1
,
1.0
),
(
16384
,
0.15
,
0
,
1
,
2.0
),
(
16384
,
0.15
,
0
,
1
,
2.0
),
(
32000
,
0.08
,
0.8
,
50
,
1.0
),
(
32000
,
0.08
,
1.0
,
25
,
1.0
),
# (119696, 0.01, 1.0, 100, 1.0),
]
# Data types used for testing
_TENSOR_DTYPES
=
[
torch
.
float16
,
torch
.
float32
]
PROFILE
=
False
NUM_PRERUN
=
10
NUM_ITERATIONS
=
1000
class
RandomSampleDescriptor
(
Structure
):
class
RandomSampleDescriptor
(
Structure
):
...
@@ -116,8 +138,8 @@ def test(
...
@@ -116,8 +138,8 @@ def test(
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_
tensor
.
descriptor
.
contents
.
invalidate
()
for
tensor
in
[
x_tensor
,
indices_tensor
]:
indices_
tensor
.
descriptor
.
contents
.
invalidate
()
tensor
.
descriptor
.
contents
.
invalidate
()
workspace_size
=
c_uint64
(
0
)
workspace_size
=
c_uint64
(
0
)
check_error
(
check_error
(
...
@@ -126,77 +148,45 @@ def test(
...
@@ -126,77 +148,45 @@ def test(
)
)
)
)
workspace
=
create_workspace
(
workspace_size
.
value
,
torch_device
)
workspace
=
create_workspace
(
workspace_size
.
value
,
torch_device
)
check_error
(
lib
.
infiniopRandomSample
(
def
lib_random_sample
():
descriptor
,
check_error
(
workspace
.
data_ptr
()
if
workspace
is
not
None
else
None
,
lib
.
infiniopRandomSample
(
workspace_size
.
value
,
descriptor
,
indices_tensor
.
data
,
workspace
.
data_ptr
()
if
workspace
is
not
None
else
None
,
x_tensor
.
data
,
workspace_size
.
value
,
random_val
,
indices_tensor
.
data
,
topp
,
x_tensor
.
data
,
topk
,
random_val
,
temperature
,
topp
,
None
,
topk
,
temperature
,
None
,
)
)
)
)
if
torch_device
==
"npu"
:
if
torch_device
==
"npu"
:
torch
.
npu
.
synchronize
()
torch
.
npu
.
synchronize
()
assert
indices
[
0
].
type
(
ans
.
dtype
)
==
ans
or
data
[
ans
]
==
data
[
indices
[
0
]]
assert
indices
[
0
].
type
(
ans
.
dtype
)
==
ans
or
data
[
ans
]
==
data
[
indices
[
0
]]
# Profiling workflow
if
PROFILE
:
# fmt: off
if
topp
>
0
and
topk
>
1
:
profile_operation
(
"PyTorch"
,
lambda
:
random_sample
(
data
.
to
(
"cpu"
),
random_val
,
topp
,
topk
,
voc
,
temperature
,
"cpu"
),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
else
:
profile_operation
(
"PyTorch"
,
lambda
:
random_sample_0
(
data
),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
profile_operation
(
" lib"
,
lambda
:
lib_random_sample
(),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
# fmt: on
check_error
(
lib
.
infiniopDestroyRandomSampleDescriptor
(
descriptor
))
check_error
(
lib
.
infiniopDestroyRandomSampleDescriptor
(
descriptor
))
def
test_cpu
(
lib
,
test_cases
):
device
=
DeviceEnum
.
DEVICE_CPU
handle
=
create_handle
(
lib
,
device
)
for
voc
,
random_val
,
topp
,
topk
,
temperature
in
test_cases
:
test
(
lib
,
handle
,
"cpu"
,
voc
,
random_val
,
topp
,
topk
,
temperature
)
destroy_handle
(
lib
,
handle
)
def
test_cuda
(
lib
,
test_cases
):
device
=
DeviceEnum
.
DEVICE_CUDA
handle
=
create_handle
(
lib
,
device
)
for
voc
,
random_val
,
topp
,
topk
,
temperature
in
test_cases
:
test
(
lib
,
handle
,
"cuda"
,
voc
,
random_val
,
topp
,
topk
,
temperature
)
destroy_handle
(
lib
,
handle
)
def
test_bang
(
lib
,
test_cases
):
import
torch_mlu
device
=
DeviceEnum
.
DEVICE_BANG
handle
=
create_handle
(
lib
,
device
)
for
voc
,
random_val
,
topp
,
topk
,
temperature
in
test_cases
:
test
(
lib
,
handle
,
"mlu"
,
voc
,
random_val
,
topp
,
topk
,
temperature
)
destroy_handle
(
lib
,
handle
)
def
test_ascend
(
lib
,
test_cases
):
import
torch_npu
device
=
DeviceEnum
.
DEVICE_ASCEND
handle
=
create_handle
(
lib
,
device
)
for
voc
,
random_val
,
topp
,
topk
,
temperature
in
test_cases
:
test
(
lib
,
handle
,
"npu"
,
voc
,
random_val
,
topp
,
topk
,
temperature
)
destroy_handle
(
lib
,
handle
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_cases
=
[
# voc, random_val, topp, topk, temperature
(
512
,
0.8
,
0.8
,
3
,
0.5
),
(
4096
,
0.05
,
0.9
,
5
,
1.0
),
(
16384
,
0.15
,
0.85
,
10
,
2.0
),
(
512
,
0.08
,
0
,
3
,
0.5
),
(
4096
,
0.5
,
0.9
,
1
,
1.0
),
(
16384
,
0.15
,
0
,
1
,
2.0
),
(
16384
,
0.15
,
0
,
1
,
2.0
),
(
32000
,
0.08
,
0.8
,
50
,
1.0
),
(
32000
,
0.08
,
1.0
,
25
,
1.0
),
# (119696, 0.01, 1.0, 100, 1.0),
]
args
=
get_args
()
args
=
get_args
()
lib
=
open_lib
()
lib
=
open_lib
()
...
@@ -229,14 +219,12 @@ if __name__ == "__main__":
...
@@ -229,14 +219,12 @@ if __name__ == "__main__":
infiniopRandomSampleDescriptor_t
,
infiniopRandomSampleDescriptor_t
,
]
]
if
args
.
cpu
:
PROFILE
=
args
.
profile
test_cpu
(
lib
,
test_cases
)
NUM_PRERUN
=
args
.
num_prerun
if
args
.
cuda
:
NUM_ITERATIONS
=
args
.
num_iterations
test_cuda
(
lib
,
test_cases
)
if
args
.
bang
:
# Execute tests
test_bang
(
lib
,
test_cases
)
for
device
in
get_test_devices
(
args
):
if
args
.
ascend
:
test_operator
(
lib
,
device
,
test
,
_TEST_CASES
,
_TENSOR_DTYPES
)
test_ascend
(
lib
,
test_cases
)
if
not
(
args
.
cpu
or
args
.
cuda
or
args
.
bang
or
args
.
ascend
):
test_cpu
(
lib
,
test_cases
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
test/infiniop/rearrange.py
View file @
ca2f34cf
import
torch
import
ctypes
import
ctypes
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_uint64
,
c_void_p
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_size_t
,
c_uint64
,
c_void_p
,
c_float
import
sys
from
libinfiniop
import
(
import
os
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
".."
,
".."
)))
from
operatorspy
import
(
open_lib
,
to_tensor
,
CTensor
,
DeviceEnum
,
infiniopHandle_t
,
infiniopHandle_t
,
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
create_handle
,
open_lib
,
destroy_handle
,
to_tensor
,
get_test_devices
,
check_error
,
check_error
,
rearrange_tensor
,
rearrange_if_needed
,
create_workspace
,
test_operator
,
get_args
,
debug
,
get_tolerance
,
profile_operation
,
)
)
from
operatorspy.tests.test_utils
import
get_args
# ==============================================================================
import
torch
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES
=
[
# ((src_shape, src_stride), (dst_shape, dst_stride))
(((
2
,
4
,
32
),
None
),
((
2
,
4
,
32
),
(
256
,
64
,
1
))),
(((
32
,
6
,
64
),
(
64
,
2560
,
1
)),
((
32
,
6
,
64
),
None
)),
(((
4
,
6
,
64
),
(
64
,
2560
,
1
)),
((
4
,
6
,
64
),
(
131072
,
64
,
1
))),
(((
1
,
32
,
64
),
(
2048
,
64
,
1
)),
((
1
,
32
,
64
),
(
2048
,
64
,
1
))),
(((
32
,
1
,
64
),
(
64
,
2560
,
1
)),
((
32
,
1
,
64
),
(
64
,
64
,
1
))),
(((
4
,
1
,
64
),
(
64
,
2560
,
1
)),
((
4
,
1
,
64
),
(
64
,
11264
,
1
))),
(((
64
,),
(
1
,)),
((
64
,),
(
1
,))),
]
# Data types used for testing
_TENSOR_DTYPES
=
[
torch
.
float16
,
torch
.
float32
]
# Tolerance map for different data types
_TOLERANCE_MAP
=
{
torch
.
float16
:
{
"atol"
:
0
,
"rtol"
:
1e-3
},
torch
.
float32
:
{
"atol"
:
0
,
"rtol"
:
1e-3
},
}
DEBUG
=
False
PROFILE
=
False
NUM_PRERUN
=
10
NUM_ITERATIONS
=
1000
class
RerrangeDescriptor
(
Structure
):
class
RerrangeDescriptor
(
Structure
):
...
@@ -43,12 +70,13 @@ def test(
...
@@ -43,12 +70,13 @@ def test(
)
)
x
=
torch
.
rand
(
x_shape
,
dtype
=
x_dtype
).
to
(
torch_device
)
x
=
torch
.
rand
(
x_shape
,
dtype
=
x_dtype
).
to
(
torch_device
)
y
=
torch
.
zeros
(
y_shape
,
dtype
=
x_dtype
).
to
(
torch_device
)
y
=
torch
.
zeros
(
y_shape
,
dtype
=
x_dtype
).
to
(
torch_device
)
if
x_stride
is
not
None
:
x
=
rearrange_tensor
(
x
,
x_stride
)
x
,
y
=
[
if
y_stride
is
not
None
:
rearrange_if_needed
(
tensor
,
stride
)
y
=
rearrange_tensor
(
y
,
y_stride
)
for
tensor
,
stride
in
zip
([
x
,
y
],
[
x_stride
,
y_stride
])
x_tensor
=
to_tensor
(
x
,
lib
)
]
y_tensor
=
to_tensor
(
y
,
lib
)
x_tensor
,
y_tensor
=
[
to_tensor
(
tensor
,
lib
)
for
tensor
in
[
x
,
y
]]
descriptor
=
infiniopRearrangeDescriptor_t
()
descriptor
=
infiniopRearrangeDescriptor_t
()
check_error
(
check_error
(
...
@@ -58,71 +86,42 @@ def test(
...
@@ -58,71 +86,42 @@ def test(
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor
.
descriptor
.
contents
.
invalidate
()
for
tensor
in
[
x_tensor
,
y_tensor
]:
y_tensor
.
descriptor
.
contents
.
invalidate
()
tensor
.
descriptor
.
contents
.
invalidate
()
def
lib_rearrange
():
check_error
(
lib
.
infiniopRearrange
(
descriptor
,
y_tensor
.
data
,
x_tensor
.
data
,
None
)
)
lib_rearrange
()
# Validate results
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
debug
(
x
,
y
,
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
x
,
y
,
atol
=
atol
,
rtol
=
rtol
)
# Profiling workflow
if
PROFILE
:
# fmt: off
profile_operation
(
"PyTorch"
,
lambda
:
rearrange_tensor
(
y
,
y_stride
),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
profile_operation
(
" lib"
,
lambda
:
lib_rearrange
(),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
# fmt: on
check_error
(
lib
.
infiniopRearrange
(
descriptor
,
y_tensor
.
data
,
x_tensor
.
data
,
None
))
assert
torch
.
allclose
(
x
,
y
,
atol
=
0
,
rtol
=
1e-3
)
check_error
(
lib
.
infiniopDestroyRearrangeDescriptor
(
descriptor
))
check_error
(
lib
.
infiniopDestroyRearrangeDescriptor
(
descriptor
))
def
test_cpu
(
lib
,
test_cases
):
device
=
DeviceEnum
.
DEVICE_CPU
handle
=
create_handle
(
lib
,
device
)
for
test_case
in
test_cases
:
x_shape
,
x_stride
=
test_case
[
0
]
y_shape
,
y_stride
=
test_case
[
1
]
test
(
lib
,
handle
,
"cpu"
,
x_shape
,
x_stride
,
y_shape
,
y_stride
)
destroy_handle
(
lib
,
handle
)
def
test_cuda
(
lib
,
test_cases
):
device
=
DeviceEnum
.
DEVICE_CUDA
handle
=
create_handle
(
lib
,
device
)
for
test_case
in
test_cases
:
x_shape
,
x_stride
=
test_case
[
0
]
y_shape
,
y_stride
=
test_case
[
1
]
test
(
lib
,
handle
,
"cuda"
,
x_shape
,
x_stride
,
y_shape
,
y_stride
)
destroy_handle
(
lib
,
handle
)
def
test_bang
(
lib
,
test_cases
):
import
torch_mlu
device
=
DeviceEnum
.
DEVICE_BANG
handle
=
create_handle
(
lib
,
device
)
for
test_case
in
test_cases
:
x_shape
,
x_stride
=
test_case
[
0
]
y_shape
,
y_stride
=
test_case
[
1
]
test
(
lib
,
handle
,
"mlu"
,
x_shape
,
x_stride
,
y_shape
,
y_stride
)
destroy_handle
(
lib
,
handle
)
def
test_ascend
(
lib
,
test_cases
):
import
torch_npu
device
=
DeviceEnum
.
DEVICE_ASCEND
handle
=
create_handle
(
lib
,
device
)
for
test_case
in
test_cases
:
x_shape
,
x_stride
=
test_case
[
0
]
y_shape
,
y_stride
=
test_case
[
1
]
test
(
lib
,
handle
,
"npu"
,
x_shape
,
x_stride
,
y_shape
,
y_stride
)
destroy_handle
(
lib
,
handle
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
get_args
()
args
=
get_args
()
test_cases
=
[
# ((src_shape, src_stride), (dst_shape, dst_stride))
(((
2
,
4
,
32
),
None
),
((
2
,
4
,
32
),
(
256
,
64
,
1
))),
(((
32
,
6
,
64
),
(
64
,
2560
,
1
)),
((
32
,
6
,
64
),
None
)),
(((
4
,
6
,
64
),
(
64
,
2560
,
1
)),
((
4
,
6
,
64
),
(
131072
,
64
,
1
))),
(((
1
,
32
,
64
),
(
2048
,
64
,
1
)),
((
1
,
32
,
64
),
(
2048
,
64
,
1
))),
(((
32
,
1
,
64
),
(
64
,
2560
,
1
)),
((
32
,
1
,
64
),
(
64
,
64
,
1
))),
(((
4
,
1
,
64
),
(
64
,
2560
,
1
)),
((
4
,
1
,
64
),
(
64
,
11264
,
1
))),
(((
64
,),
(
1
,)),
((
64
,),
(
1
,))),
]
lib
=
open_lib
()
lib
=
open_lib
()
lib
.
infiniopCreateRearrangeDescriptor
.
restype
=
c_int32
lib
.
infiniopCreateRearrangeDescriptor
.
restype
=
c_int32
lib
.
infiniopCreateRearrangeDescriptor
.
argtypes
=
[
lib
.
infiniopCreateRearrangeDescriptor
.
argtypes
=
[
infiniopHandle_t
,
infiniopHandle_t
,
...
@@ -139,12 +138,15 @@ if __name__ == "__main__":
...
@@ -139,12 +138,15 @@ if __name__ == "__main__":
]
]
lib
.
infiniopDestroyRearrangeDescriptor
.
restype
=
c_int32
lib
.
infiniopDestroyRearrangeDescriptor
.
restype
=
c_int32
lib
.
infiniopDestroyRearrangeDescriptor
.
argtypes
=
[
infiniopRearrangeDescriptor_t
]
lib
.
infiniopDestroyRearrangeDescriptor
.
argtypes
=
[
infiniopRearrangeDescriptor_t
]
if
args
.
cpu
:
test_cpu
(
lib
,
test_cases
)
# Configure testing options
if
args
.
cuda
:
DEBUG
=
args
.
debug
test_cuda
(
lib
,
test_cases
)
PROFILE
=
args
.
profile
if
args
.
bang
:
NUM_PRERUN
=
args
.
num_prerun
test_bang
(
lib
,
test_cases
)
NUM_ITERATIONS
=
args
.
num_iterations
if
args
.
ascend
:
test_ascend
(
lib
,
test_cases
)
# Execute tests
for
device
in
get_test_devices
(
args
):
test_operator
(
lib
,
device
,
test
,
_TEST_CASES
,
_TENSOR_DTYPES
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
test/infiniop/rms_norm.py
View file @
ca2f34cf
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_uint64
,
c_void_p
,
c_float
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_uint64
,
c_void_p
,
c_float
import
ctypes
import
ctypes
import
sys
import
torch
import
os
import
ctypes
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_size_t
,
c_uint64
,
c_void_p
,
c_float
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
".."
,
".."
)))
from
libinfiniop
import
(
from
operatorspy
import
(
open_lib
,
to_tensor
,
DeviceEnum
,
infiniopHandle_t
,
infiniopHandle_t
,
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
create_handle
,
open_lib
,
destroy_handle
,
to_tensor
,
get_test_devices
,
check_error
,
check_error
,
rearrange_
tensor
,
rearrange_
if_needed
,
create_workspace
,
create_workspace
,
test_operator
,
get_args
,
debug
,
get_tolerance
,
profile_operation
,
)
)
from
operatorspy.tests.test_utils
import
get_args
# ==============================================================================
import
torch
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES
=
[
# y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
None
,
None
,
torch
.
float32
),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
None
,
None
,
torch
.
float16
),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
(
4096
,
1
),
(
4096
,
1
),
torch
.
float32
),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
(
4096
,
1
),
(
4096
,
1
),
torch
.
float16
),
]
# x types used for testing
_TENSOR_DTYPES
=
[
torch
.
float16
,
torch
.
float32
]
# Tolerance map for different data types
_TOLERANCE_MAP
=
{
torch
.
float16
:
{
"atol"
:
0
,
"rtol"
:
1e-2
},
torch
.
float32
:
{
"atol"
:
0
,
"rtol"
:
1e-3
},
}
DEBUG
=
False
PROFILE
=
False
NUM_PRERUN
=
10
NUM_ITERATIONS
=
1000
class
RMSNormDescriptor
(
Structure
):
class
RMSNormDescriptor
(
Structure
):
_fields_
=
[(
"device"
,
c_int32
)]
_fields_
=
[(
"device"
,
c_int32
)]
...
@@ -27,7 +51,6 @@ class RMSNormDescriptor(Structure):
...
@@ -27,7 +51,6 @@ class RMSNormDescriptor(Structure):
infiniopRMSNormDescriptor_t
=
POINTER
(
RMSNormDescriptor
)
infiniopRMSNormDescriptor_t
=
POINTER
(
RMSNormDescriptor
)
def
rms_norm
(
x
,
w
,
eps
):
def
rms_norm
(
x
,
w
,
eps
):
input_dtype
=
x
.
dtype
input_dtype
=
x
.
dtype
hidden_states
=
x
.
to
(
torch
.
float32
)
hidden_states
=
x
.
to
(
torch
.
float32
)
...
@@ -37,19 +60,18 @@ def rms_norm(x, w, eps):
...
@@ -37,19 +60,18 @@ def rms_norm(x, w, eps):
def
test
(
def
test
(
lib
,
lib
,
handle
,
handle
,
torch_device
,
torch_device
,
y_shape
,
y_shape
,
x_shape
,
x_shape
,
w_shape
,
w_shape
,
dtype
=
torch
.
float16
,
y_stride
,
w_dtype
=
torch
.
float16
,
x_stride
,
):
dtype
=
torch
.
float16
,
print
(
w_dtype
=
torch
.
float16
):
f
"Testing RMS_Norm on
{
torch_device
}
with y_shape:
{
y_shape
}
x_shape:
{
x_shape
}
w_shape:
{
w_shape
}
"
print
(
f
"Testing RMS_Norm on
{
torch_device
}
with y_shape:
{
y_shape
}
x_shape:
{
x_shape
}
w_shape:
{
w_shape
}
"
f
" dtype:
{
dtype
}
w_dtype:
{
w_dtype
}
"
f
" dtype:
{
dtype
}
w_dtype:
{
w_dtype
}
"
)
)
y
=
torch
.
zeros
(
y_shape
,
dtype
=
dtype
).
to
(
torch_device
)
y
=
torch
.
zeros
(
y_shape
,
dtype
=
dtype
).
to
(
torch_device
)
x
=
torch
.
rand
(
x_shape
,
dtype
=
dtype
).
to
(
torch_device
)
x
=
torch
.
rand
(
x_shape
,
dtype
=
dtype
).
to
(
torch_device
)
...
@@ -58,93 +80,64 @@ def test(
...
@@ -58,93 +80,64 @@ def test(
eps
=
1e-5
eps
=
1e-5
ans
=
rms_norm
(
x
,
w
,
eps
)
ans
=
rms_norm
(
x
,
w
,
eps
)
y_tensor
=
to_tensor
(
y
,
lib
)
x
=
rearrange_if_needed
(
x
,
x_stride
)
x_tensor
=
to_tensor
(
x
,
lib
)
y
=
rearrange_if_needed
(
y
,
y_stride
)
w_tensor
=
to_tensor
(
w
,
lib
)
x_tensor
,
y_tensor
,
w_tensor
=
[
to_tensor
(
tensor
,
lib
)
for
tensor
in
[
x
,
y
,
w
]]
descriptor
=
infiniopRMSNormDescriptor_t
()
descriptor
=
infiniopRMSNormDescriptor_t
()
w_dataType
=
0
if
w_dtype
==
torch
.
float16
else
1
w_dataType
=
0
if
w_dtype
==
torch
.
float16
else
1
check_error
(
check_error
(
lib
.
infiniopCreateRMSNormDescriptor
(
lib
.
infiniopCreateRMSNormDescriptor
(
handle
,
handle
,
ctypes
.
byref
(
descriptor
),
y_tensor
.
descriptor
,
x_tensor
.
descriptor
,
ctypes
.
byref
(
descriptor
),
w_tensor
.
descriptor
,
eps
y_tensor
.
descriptor
,
x_tensor
.
descriptor
,
w_tensor
.
descriptor
,
eps
,
)
)
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
x_tensor
.
descriptor
.
contents
.
invalidate
()
for
tensor
in
[
x_tensor
,
y_tensor
,
w_tensor
]:
y_tensor
.
descriptor
.
contents
.
invalidate
()
tensor
.
descriptor
.
contents
.
invalidate
()
w_tensor
.
descriptor
.
contents
.
invalidate
()
workspace_size
=
c_uint64
(
0
)
workspace_size
=
c_uint64
(
0
)
check_error
(
check_error
(
lib
.
infiniopGetRMSNormWorkspaceSize
(
descriptor
,
ctypes
.
byref
(
workspace_size
))
lib
.
infiniopGetRMSNormWorkspaceSize
(
descriptor
,
ctypes
.
byref
(
workspace_size
)
)
)
)
workspace
=
create_workspace
(
workspace_size
.
value
,
y
.
device
)
workspace
=
create_workspace
(
workspace_size
.
value
,
y
.
device
)
check_error
(
def
lib_rms_norm
():
lib
.
infiniopRMSNorm
(
check_error
(
descriptor
,
lib
.
infiniopRMSNorm
(
workspace
.
data_ptr
()
if
workspace
is
not
None
else
None
,
descriptor
,
workspace_size
.
value
,
workspace
.
data_ptr
()
if
workspace
is
not
None
else
None
,
y_tensor
.
data
,
workspace_size
.
value
,
x_tensor
.
data
,
y_tensor
.
data
,
w_tensor
.
data
,
x_tensor
.
data
,
None
,
w_tensor
.
data
,
None
,
)
)
)
)
assert
torch
.
allclose
(
y
.
to
(
dtype
),
ans
.
to
(
dtype
),
atol
=
1e-3
,
rtol
=
1e-3
)
lib_rms_norm
()
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
debug
(
y
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
y
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
# Profiling workflow
if
PROFILE
:
# fmt: off
profile_operation
(
"PyTorch"
,
lambda
:
rms_norm
(
x
,
w
,
eps
),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
profile_operation
(
" lib"
,
lambda
:
lib_rms_norm
(),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
# fmt: on
check_error
(
lib
.
infiniopDestroyRMSNormDescriptor
(
descriptor
))
check_error
(
lib
.
infiniopDestroyRMSNormDescriptor
(
descriptor
))
def
test_cpu
(
lib
,
test_cases
):
device
=
DeviceEnum
.
DEVICE_CPU
handle
=
create_handle
(
lib
,
device
)
for
y_shape
,
x_shape
,
w_shape
,
dtype
,
w_dtype
in
test_cases
:
test
(
lib
,
handle
,
"cpu"
,
y_shape
,
x_shape
,
w_shape
,
dtype
,
w_dtype
)
destroy_handle
(
lib
,
handle
)
def
test_cuda
(
lib
,
test_cases
):
device
=
DeviceEnum
.
DEVICE_CUDA
handle
=
create_handle
(
lib
,
device
)
for
y_shape
,
x_shape
,
w_shape
,
dtype
,
w_dtype
in
test_cases
:
test
(
lib
,
handle
,
"cuda"
,
y_shape
,
x_shape
,
w_shape
,
dtype
,
w_dtype
)
destroy_handle
(
lib
,
handle
)
def
test_bang
(
lib
,
test_cases
):
import
torch_mlu
device
=
DeviceEnum
.
DEVICE_BANG
handle
=
create_handle
(
lib
,
device
)
for
y_shape
,
x_shape
,
w_shape
,
dtype
,
w_dtype
in
test_cases
:
test
(
lib
,
handle
,
"mlu"
,
y_shape
,
x_shape
,
w_shape
,
dtype
,
w_dtype
)
destroy_handle
(
lib
,
handle
)
def
test_ascend
(
lib
,
test_cases
):
import
torch_npu
device
=
DeviceEnum
.
DEVICE_ASCEND
handle
=
create_handle
(
lib
,
device
)
for
y_shape
,
x_shape
,
w_shape
,
dtype
,
w_dtype
in
test_cases
:
test
(
lib
,
handle
,
"npu"
,
y_shape
,
x_shape
,
w_shape
,
dtype
,
w_dtype
)
destroy_handle
(
lib
,
handle
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_cases
=
[
# y_shape, x_shape, w_shape, dtype, w_dtype
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
torch
.
float16
,
torch
.
float16
),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
torch
.
float16
,
torch
.
float32
),
]
args
=
get_args
()
args
=
get_args
()
lib
=
open_lib
()
lib
=
open_lib
()
lib
.
infiniopCreateRMSNormDescriptor
.
restype
=
c_int32
lib
.
infiniopCreateRMSNormDescriptor
.
restype
=
c_int32
...
@@ -178,14 +171,16 @@ if __name__ == "__main__":
...
@@ -178,14 +171,16 @@ if __name__ == "__main__":
infiniopRMSNormDescriptor_t
,
infiniopRMSNormDescriptor_t
,
]
]
if
args
.
cpu
:
# Configure testing options
test_cpu
(
lib
,
test_cases
)
DEBUG
=
args
.
debug
if
args
.
cuda
:
PROFILE
=
args
.
profile
test_cuda
(
lib
,
test_cases
)
NUM_PRERUN
=
args
.
num_prerun
if
args
.
bang
:
NUM_ITERATIONS
=
args
.
num_iterations
test_bang
(
lib
,
test_cases
)
if
args
.
ascend
:
# Execute tests
test_ascend
(
lib
,
test_case
s
)
for
device
in
get_test_devices
(
arg
s
)
:
if
not
(
args
.
cpu
or
args
.
cuda
or
args
.
bang
or
args
.
ascend
):
test_operator
(
lib
,
device
,
test
,
_TEST_CASES
,
_TENSOR_DTYPES
)
test_cpu
(
lib
,
test_cases
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
test/infiniop/rotary_embedding.py
View file @
ca2f34cf
import
torch
import
ctypes
import
ctypes
from
ctypes
import
POINTER
,
c_void_p
,
c_int32
,
c_uint64
,
Structure
,
byref
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_size_t
,
c_uint64
,
c_void_p
,
c_float
import
sys
import
os
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
".."
,
".."
)))
from
libinfiniop
import
(
from
libinfiniop
import
(
infiniopHandle_t
,
infiniopHandle_t
,
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
...
@@ -16,10 +13,33 @@ from libinfiniop import (
...
@@ -16,10 +13,33 @@ from libinfiniop import (
test_operator
,
test_operator
,
get_args
,
get_args
,
debug
,
debug
,
get_tolerance
,
profile_operation
,
profile_operation
,
InfiniDtype
,
)
)
import
torch
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES
=
[
# (t_shape, t_strides)
((
1
,
32
,
128
),
None
),
((
1
,
32
,
64
),
None
),
# 昇腾暂不满足这个用例,最后一维度 <=32 会有问题,可能与其核心
# 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持
((
4
,
1
,
32
),
None
),
((
1
,
32
,
128
),
None
),
((
3
,
32
,
128
),
(
8000
,
200
,
1
)),
]
# Data types used for testing
_TENSOR_DTYPES
=
[
torch
.
float16
,
torch
.
float32
]
# Tolerance map for different data types
_TOLERANCE_MAP
=
{
torch
.
float16
:
{
"atol"
:
0
,
"rtol"
:
1e-2
},
torch
.
float32
:
{
"atol"
:
0
,
"rtol"
:
1e-3
},
}
DEBUG
=
False
DEBUG
=
False
PROFILE
=
False
PROFILE
=
False
...
@@ -27,6 +47,7 @@ NUM_PRERUN = 10
...
@@ -27,6 +47,7 @@ NUM_PRERUN = 10
NUM_ITERATIONS
=
1000
NUM_ITERATIONS
=
1000
class
RoPEDescriptor
(
Structure
):
class
RoPEDescriptor
(
Structure
):
_fields_
=
[(
"device"
,
c_int32
)]
_fields_
=
[(
"device"
,
c_int32
)]
...
@@ -75,13 +96,22 @@ def sin_cos_table(max_seq_len, dim, torch_device, theta):
...
@@ -75,13 +96,22 @@ def sin_cos_table(max_seq_len, dim, torch_device, theta):
return
torch
.
sin
(
angles
),
torch
.
cos
(
angles
)
return
torch
.
sin
(
angles
),
torch
.
cos
(
angles
)
def
test
(
lib
,
handle
,
torch_device
,
shape
,
strides
=
None
,
dtype
=
torch
.
float16
):
def
test
(
lib
,
handle
,
torch_device
,
shape
,
strides
=
None
,
dtype
=
torch
.
float16
):
print
(
print
(
f
"Testing Rotary Positional Embedding on
{
torch_device
}
with shape:
{
shape
}
strides:
{
strides
}
and dtype:
{
dtype
}
"
f
"Testing Rotary Positional Embedding on
{
torch_device
}
with shape:
{
shape
}
strides:
{
strides
}
and dtype:
{
dtype
}
"
)
)
t
=
torch
.
rand
(
shape
,
dtype
=
dtype
)
t
=
torch
.
rand
(
shape
,
dtype
=
dtype
)
t
=
rearrange_if_needed
(
t
,
strides
).
to
(
torch_device
)
t
=
rearrange_if_needed
(
t
,
strides
)
posTmp
=
torch
.
arange
(
0
,
t
.
shape
[
0
]).
to
(
torch_device
)
posTmp
=
torch
.
arange
(
0
,
t
.
shape
[
0
]).
to
(
torch_device
)
pos
=
torch
.
zeros
(
2
*
posTmp
.
shape
[
0
],
dtype
=
torch
.
int32
)
pos
=
torch
.
zeros
(
2
*
posTmp
.
shape
[
0
],
dtype
=
torch
.
int32
)
for
i
in
range
(
posTmp
.
shape
[
0
]):
for
i
in
range
(
posTmp
.
shape
[
0
]):
...
@@ -95,11 +125,12 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
...
@@ -95,11 +125,12 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
descriptor
=
infiniopRoPEDescriptor_t
()
descriptor
=
infiniopRoPEDescriptor_t
()
# 2x table length for test
# 2x table length for test
sin_table
,
cos_table
=
sin_cos_table
(
t
.
shape
[
0
]
*
2
,
t
.
shape
[
2
],
t
.
device
,
theta
)
sin_table
,
cos_table
=
sin_cos_table
(
t
.
shape
[
0
]
*
2
,
t
.
shape
[
2
],
t
.
device
,
theta
)
t_tensor
=
to_tensor
(
t
,
lib
)
t_tensor
,
sin_table_tensor
,
cos_table_tensor
=
[
to_tensor
(
tensor
,
lib
)
for
tensor
in
[
t
,
sin_table
,
cos_table
]]
pos_tensor
=
to_tensor
(
pos
[:
t
.
shape
[
0
]],
lib
)
pos_tensor
=
to_tensor
(
pos
[:
t
.
shape
[
0
]],
lib
)
pos_tensor
.
descriptor
.
contents
.
dtype
=
InfiniDtype
.
U64
pos_tensor
.
descriptor
.
contents
.
dtype
=
InfiniDtype
.
U64
sin_table_tensor
=
to_tensor
(
sin_table
,
lib
)
cos_table_tensor
=
to_tensor
(
cos_table
,
lib
)
if
torch_device
==
"npu"
:
if
torch_device
==
"npu"
:
torch
.
npu
.
synchronize
()
torch
.
npu
.
synchronize
()
...
@@ -116,10 +147,8 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
...
@@ -116,10 +147,8 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
t_tensor
.
descriptor
.
contents
.
invalidate
()
for
tensor
in
[
t_tensor
,
pos_tensor
,
sin_table_tensor
,
cos_table_tensor
]:
pos_tensor
.
descriptor
.
contents
.
invalidate
()
tensor
.
descriptor
.
contents
.
invalidate
()
sin_table_tensor
.
descriptor
.
contents
.
invalidate
()
cos_table_tensor
.
descriptor
.
contents
.
invalidate
()
workspace_size
=
c_uint64
(
0
)
workspace_size
=
c_uint64
(
0
)
check_error
(
check_error
(
...
@@ -142,9 +171,11 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
...
@@ -142,9 +171,11 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
)
)
lib_rope
()
lib_rope
()
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
if
DEBUG
:
debug
(
t
,
ans
,
atol
=
1e-4
,
rtol
=
1e-2
)
debug
(
t
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
t
,
ans
,
atol
=
1e-4
,
rtol
=
1e-2
)
assert
torch
.
allclose
(
t
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
if
PROFILE
:
if
PROFILE
:
profile_operation
(
profile_operation
(
"PyTorch"
,
"PyTorch"
,
...
@@ -161,17 +192,6 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
...
@@ -161,17 +192,6 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_cases
=
[
# (t_shape, t_strides)
((
1
,
32
,
128
),
None
),
((
1
,
32
,
64
),
None
),
# 昇腾暂不满足这个用例,最后一维度 <=32 会有问题,可能与其核心
# 接口 GatherMask 的内部实现相关,目前 48 64 128 都可以支持
((
4
,
1
,
32
),
None
),
((
1
,
32
,
128
),
None
),
((
3
,
32
,
128
),
(
8000
,
200
,
1
)),
]
test_dtypes
=
[
torch
.
float16
]
args
=
get_args
()
args
=
get_args
()
lib
=
open_lib
()
lib
=
open_lib
()
lib
.
infiniopCreateRoPEDescriptor
.
restype
=
c_int32
lib
.
infiniopCreateRoPEDescriptor
.
restype
=
c_int32
...
@@ -211,5 +231,5 @@ if __name__ == "__main__":
...
@@ -211,5 +231,5 @@ if __name__ == "__main__":
# Execute tests
# Execute tests
for
device
in
get_test_devices
(
args
):
for
device
in
get_test_devices
(
args
):
test_operator
(
lib
,
device
,
test
,
test_cases
,
test_dtypes
)
test_operator
(
lib
,
device
,
test
,
_TEST_CASES
,
_TENSOR_DTYPES
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
test/infiniop/swiglu.py
View file @
ca2f34cf
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_uint64
,
c_void_p
import
torch
import
ctypes
import
ctypes
import
sys
from
ctypes
import
POINTER
,
Structure
,
c_int32
,
c_size_t
,
c_uint64
,
c_void_p
,
c_float
import
os
from
libinfiniop
import
(
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
".."
,
".."
)))
from
operatorspy
import
(
open_lib
,
to_tensor
,
CTensor
,
DeviceEnum
,
infiniopHandle_t
,
infiniopHandle_t
,
infiniopTensorDescriptor_t
,
infiniopTensorDescriptor_t
,
create_handle
,
open_lib
,
destroy_handle
,
to_tensor
,
get_test_devices
,
check_error
,
check_error
,
rearrange_tensor
,
rearrange_if_needed
,
create_workspace
,
test_operator
,
get_args
,
debug
,
get_tolerance
,
profile_operation
,
)
)
from
operatorspy.tests.test_utils
import
get_args
# ==============================================================================
import
torch
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES
=
[
# shape, a_stride, b_stride, c_stride
((
13
,
4
),
None
,
None
,
None
),
((
13
,
4
),
(
10
,
1
),
(
10
,
1
),
(
10
,
1
)),
((
13
,
4
,
4
),
None
,
None
,
None
),
((
13
,
4
,
4
),
(
20
,
4
,
1
),
(
20
,
4
,
1
),
(
20
,
4
,
1
)),
((
16
,
5632
),
None
,
None
,
None
),
((
16
,
5632
),
(
13312
,
1
),
(
13312
,
1
),
(
13312
,
1
)),
((
4
,
4
,
5632
),
None
,
None
,
None
),
((
4
,
4
,
5632
),
(
45056
,
5632
,
1
),
(
45056
,
5632
,
1
),
(
45056
,
5632
,
1
)),
]
# Data types used for testing
_TENSOR_DTYPES
=
[
torch
.
float16
,
torch
.
float32
]
# Tolerance map for different data types
_TOLERANCE_MAP
=
{
torch
.
float16
:
{
'atol'
:
0
,
'rtol'
:
1e-2
},
torch
.
float32
:
{
'atol'
:
0
,
'rtol'
:
1e-3
},
}
DEBUG
=
False
PROFILE
=
False
NUM_PRERUN
=
10
NUM_ITERATIONS
=
1000
class
SwiGLUDescriptor
(
Structure
):
class
SwiGLUDescriptor
(
Structure
):
_fields_
=
[(
"device"
,
c_int32
)]
_fields_
=
[(
"device"
,
c_int32
)]
...
@@ -51,20 +76,18 @@ def test_out_of_place(
...
@@ -51,20 +76,18 @@ def test_out_of_place(
b
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
b
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
c
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
c
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
if
a_stride
is
not
None
:
a
=
rearrange_tensor
(
a
,
a_stride
)
if
b_stride
is
not
None
:
b
=
rearrange_tensor
(
b
,
b_stride
)
if
c_stride
is
not
None
:
c
=
rearrange_tensor
(
c
,
c_stride
)
ans
=
swiglu
(
a
,
b
)
ans
=
swiglu
(
a
,
b
)
a
,
b
,
c
=
[
rearrange_if_needed
(
tensor
,
stride
)
for
tensor
,
stride
in
zip
([
a
,
b
,
c
],
[
a_stride
,
b_stride
,
c_stride
])
]
a_tensor
,
b_tensor
,
c_tensor
=
[
to_tensor
(
tensor
,
lib
)
for
tensor
in
[
a
,
b
,
c
]]
if
sync
is
not
None
:
if
sync
is
not
None
:
sync
()
sync
()
a_tensor
=
to_tensor
(
a
,
lib
)
b_tensor
=
to_tensor
(
b
,
lib
)
c_tensor
=
to_tensor
(
c
,
lib
)
descriptor
=
infiniopSwiGLUDescriptor_t
()
descriptor
=
infiniopSwiGLUDescriptor_t
()
check_error
(
check_error
(
lib
.
infiniopCreateSwiGLUDescriptor
(
lib
.
infiniopCreateSwiGLUDescriptor
(
...
@@ -77,19 +100,33 @@ def test_out_of_place(
...
@@ -77,19 +100,33 @@ def test_out_of_place(
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
a_tensor
.
descriptor
.
contents
.
invalidate
()
for
tensor
in
[
a_tensor
,
b_tensor
,
c_tensor
]:
b_tensor
.
descriptor
.
contents
.
invalidate
()
tensor
.
descriptor
.
contents
.
invalidate
()
c_tensor
.
descriptor
.
contents
.
invalidate
()
def
lib_swiglu
():
check_error
(
check_error
(
lib
.
infiniopSwiGLU
(
lib
.
infiniopSwiGLU
(
descriptor
,
c_tensor
.
data
,
a_tensor
.
data
,
b_tensor
.
data
,
None
descriptor
,
c_tensor
.
data
,
a_tensor
.
data
,
b_tensor
.
data
,
None
)
)
)
)
lib_swiglu
(
)
assert
torch
.
allclose
(
c
,
ans
,
atol
=
1e-4
,
rtol
=
1e-2
)
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
debug
(
c
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
c
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
print
(
"out-of-place Test passed!"
)
print
(
"out-of-place Test passed!"
)
# Profiling workflow
if
PROFILE
:
# fmt: off
profile_operation
(
"PyTorch"
,
lambda
:
swiglu
(
a
,
b
),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
profile_operation
(
" lib"
,
lambda
:
lib_swiglu
(),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
# fmt: on
check_error
(
lib
.
infiniopDestroySwiGLUDescriptor
(
descriptor
))
check_error
(
lib
.
infiniopDestroySwiGLUDescriptor
(
descriptor
))
...
@@ -106,18 +143,19 @@ def test_in_place1(
...
@@ -106,18 +143,19 @@ def test_in_place1(
a
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
a
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
b
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
b
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
if
a_stride
is
not
None
:
a
=
rearrange_tensor
(
a
,
a_stride
)
if
b_stride
is
not
None
:
b
=
rearrange_tensor
(
b
,
b_stride
)
ans
=
swiglu
(
a
,
b
)
ans
=
swiglu
(
a
,
b
)
if
sync
is
not
None
:
if
sync
is
not
None
:
sync
()
sync
()
a_tensor
=
to_tensor
(
a
,
lib
)
a
,
b
=
[
b_tensor
=
to_tensor
(
b
,
lib
)
rearrange_if_needed
(
tensor
,
stride
)
for
tensor
,
stride
in
zip
([
a
,
b
],
[
a_stride
,
b_stride
])
]
a_tensor
,
b_tensor
=
[
to_tensor
(
tensor
,
lib
)
for
tensor
in
[
a
,
b
]]
descriptor
=
infiniopSwiGLUDescriptor_t
()
descriptor
=
infiniopSwiGLUDescriptor_t
()
check_error
(
check_error
(
lib
.
infiniopCreateSwiGLUDescriptor
(
lib
.
infiniopCreateSwiGLUDescriptor
(
handle
,
handle
,
...
@@ -129,18 +167,27 @@ def test_in_place1(
...
@@ -129,18 +167,27 @@ def test_in_place1(
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
a_tensor
.
descriptor
.
contents
.
invalidate
()
for
tensor
in
[
a_tensor
,
b_tensor
]:
b_tensor
.
descriptor
.
contents
.
invalidate
()
tensor
.
descriptor
.
contents
.
invalidate
()
def
lib_swiglu
():
check_error
(
check_error
(
lib
.
infiniopSwiGLU
(
lib
.
infiniopSwiGLU
(
descriptor
,
a_tensor
.
data
,
a_tensor
.
data
,
b_tensor
.
data
,
None
descriptor
,
a_tensor
.
data
,
a_tensor
.
data
,
b_tensor
.
data
,
None
)
)
)
)
lib_swiglu
(
)
assert
torch
.
allclose
(
a
,
ans
,
atol
=
1e-4
,
rtol
=
1e-2
)
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
debug
(
a
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
a
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
print
(
"in-place1 Test passed!"
)
print
(
"in-place1 Test passed!"
)
# Profiling workflow
if
PROFILE
:
# fmt: off
profile_operation
(
"PyTorch"
,
lambda
:
swiglu
(
a
,
b
),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
profile_operation
(
" lib"
,
lambda
:
lib_swiglu
(),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
# fmt: on
check_error
(
lib
.
infiniopDestroySwiGLUDescriptor
(
descriptor
))
check_error
(
lib
.
infiniopDestroySwiGLUDescriptor
(
descriptor
))
...
@@ -157,17 +204,17 @@ def test_in_place2(
...
@@ -157,17 +204,17 @@ def test_in_place2(
a
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
a
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
b
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
b
=
torch
.
rand
(
shape
,
dtype
=
dtype
).
to
(
torch_device
)
if
a_stride
is
not
None
:
a
=
rearrange_tensor
(
a
,
a_stride
)
if
b_stride
is
not
None
:
b
=
rearrange_tensor
(
b
,
b_stride
)
ans
=
swiglu
(
a
,
b
)
ans
=
swiglu
(
a
,
b
)
if
sync
is
not
None
:
if
sync
is
not
None
:
sync
()
sync
()
a_tensor
=
to_tensor
(
a
,
lib
)
a
,
b
=
[
b_tensor
=
to_tensor
(
b
,
lib
)
rearrange_if_needed
(
tensor
,
stride
)
for
tensor
,
stride
in
zip
([
a
,
b
],
[
a_stride
,
b_stride
])
]
a_tensor
,
b_tensor
=
[
to_tensor
(
tensor
,
lib
)
for
tensor
in
[
a
,
b
]]
descriptor
=
infiniopSwiGLUDescriptor_t
()
descriptor
=
infiniopSwiGLUDescriptor_t
()
check_error
(
check_error
(
lib
.
infiniopCreateSwiGLUDescriptor
(
lib
.
infiniopCreateSwiGLUDescriptor
(
...
@@ -180,100 +227,42 @@ def test_in_place2(
...
@@ -180,100 +227,42 @@ def test_in_place2(
)
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
a_tensor
.
descriptor
.
contents
.
invalidate
()
for
tensor
in
[
a_tensor
,
b_tensor
]:
b_tensor
.
descriptor
.
contents
.
invalidate
()
tensor
.
descriptor
.
contents
.
invalidate
()
check_error
(
def
lib_swiglu
():
lib
.
infiniopSwiGLU
(
check_error
(
descriptor
,
b_tensor
.
data
,
a_tensor
.
data
,
b_tensor
.
data
,
None
lib
.
infiniopSwiGLU
(
descriptor
,
b_tensor
.
data
,
a_tensor
.
data
,
b_tensor
.
data
,
None
)
)
)
)
lib_swiglu
()
assert
torch
.
allclose
(
b
,
ans
,
atol
=
1e-4
,
rtol
=
1e-2
)
atol
,
rtol
=
get_tolerance
(
_TOLERANCE_MAP
,
dtype
)
if
DEBUG
:
debug
(
b
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
b
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
print
(
"in-place2 Test passed!"
)
# Profiling workflow
if
PROFILE
:
# fmt: off
profile_operation
(
"PyTorch"
,
lambda
:
swiglu
(
a
,
b
),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
profile_operation
(
" lib"
,
lambda
:
lib_swiglu
(),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
# fmt: on
check_error
(
lib
.
infiniopDestroySwiGLUDescriptor
(
descriptor
))
check_error
(
lib
.
infiniopDestroySwiGLUDescriptor
(
descriptor
))
def
test_cpu
(
lib
,
test_cases
):
def
test
(
lib
,
handle
,
torch_device
,
shape
,
a_stride
,
b_stride
,
c_stride
,
dtype
,
sync
=
None
):
device
=
DeviceEnum
.
DEVICE_CPU
test_out_of_place
(
handle
=
create_handle
(
lib
,
device
)
lib
,
handle
,
torch_device
,
shape
,
a_stride
,
b_stride
,
c_stride
,
dtype
,
sync
)
for
shape
,
a_stride
,
b_stride
,
c_stride
,
dtype
in
test_cases
:
test_in_place1
(
lib
,
handle
,
torch_device
,
shape
,
a_stride
,
b_stride
,
dtype
,
sync
)
test_out_of_place
(
test_in_place2
(
lib
,
handle
,
torch_device
,
shape
,
a_stride
,
b_stride
,
dtype
,
sync
)
lib
,
handle
,
"cpu"
,
shape
,
a_stride
,
b_stride
,
c_stride
,
dtype
)
test_in_place1
(
lib
,
handle
,
"cpu"
,
shape
,
a_stride
,
b_stride
,
dtype
)
test_in_place2
(
lib
,
handle
,
"cpu"
,
shape
,
a_stride
,
b_stride
,
dtype
)
destroy_handle
(
lib
,
handle
)
def
test_cuda
(
lib
,
test_cases
):
device
=
DeviceEnum
.
DEVICE_CUDA
handle
=
create_handle
(
lib
,
device
)
for
shape
,
a_stride
,
b_stride
,
c_stride
,
dtype
in
test_cases
:
test_out_of_place
(
lib
,
handle
,
"cuda"
,
shape
,
a_stride
,
b_stride
,
c_stride
,
dtype
)
test_in_place1
(
lib
,
handle
,
"cuda"
,
shape
,
a_stride
,
b_stride
,
dtype
)
test_in_place2
(
lib
,
handle
,
"cuda"
,
shape
,
a_stride
,
b_stride
,
dtype
)
destroy_handle
(
lib
,
handle
)
def
test_bang
(
lib
,
test_cases
):
import
torch_mlu
device
=
DeviceEnum
.
DEVICE_BANG
handle
=
create_handle
(
lib
,
device
)
for
shape
,
a_stride
,
b_stride
,
c_stride
,
dtype
in
test_cases
:
test_out_of_place
(
lib
,
handle
,
"mlu"
,
shape
,
a_stride
,
b_stride
,
c_stride
,
dtype
)
test_in_place1
(
lib
,
handle
,
"mlu"
,
shape
,
a_stride
,
b_stride
,
dtype
)
test_in_place2
(
lib
,
handle
,
"mlu"
,
shape
,
a_stride
,
b_stride
,
dtype
)
destroy_handle
(
lib
,
handle
)
def
test_ascend
(
lib
,
test_cases
):
import
torch_npu
device
=
DeviceEnum
.
DEVICE_ASCEND
handle
=
create_handle
(
lib
,
device
)
for
shape
,
a_stride
,
b_stride
,
c_stride
,
dtype
in
test_cases
:
test_out_of_place
(
lib
,
handle
,
"npu"
,
shape
,
a_stride
,
b_stride
,
c_stride
,
dtype
,
torch
.
npu
.
synchronize
,
)
test_in_place1
(
lib
,
handle
,
"npu"
,
shape
,
a_stride
,
b_stride
,
dtype
,
torch
.
npu
.
synchronize
)
test_in_place2
(
lib
,
handle
,
"npu"
,
shape
,
a_stride
,
b_stride
,
dtype
,
torch
.
npu
.
synchronize
)
destroy_handle
(
lib
,
handle
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_cases
=
[
# shape, a_stride, b_stride, c_stride, dtype
((
13
,
4
),
None
,
None
,
None
,
torch
.
float16
),
((
13
,
4
),
(
10
,
1
),
(
10
,
1
),
(
10
,
1
),
torch
.
float16
),
((
16
,
5632
),
None
,
None
,
None
,
torch
.
float16
),
((
16
,
5632
),
(
13312
,
1
),
(
13312
,
1
),
(
13312
,
1
),
torch
.
float16
),
]
args
=
get_args
()
args
=
get_args
()
lib
=
open_lib
()
lib
=
open_lib
()
...
@@ -299,13 +288,13 @@ if __name__ == "__main__":
...
@@ -299,13 +288,13 @@ if __name__ == "__main__":
lib
.
infiniopDestroySwiGLUDescriptor
.
argtypes
=
[
lib
.
infiniopDestroySwiGLUDescriptor
.
argtypes
=
[
infiniopSwiGLUDescriptor_t
,
infiniopSwiGLUDescriptor_t
,
]
]
# Configure testing options
DEBUG
=
args
.
debug
PROFILE
=
args
.
profile
NUM_PRERUN
=
args
.
num_prerun
NUM_ITERATIONS
=
args
.
num_iterations
for
device
in
get_test_devices
(
args
):
test_operator
(
lib
,
device
,
test
,
_TEST_CASES
,
_TENSOR_DTYPES
)
if
args
.
cpu
:
test_cpu
(
lib
,
test_cases
)
if
args
.
cuda
:
test_cuda
(
lib
,
test_cases
)
if
args
.
bang
:
test_bang
(
lib
,
test_cases
)
if
args
.
ascend
:
test_ascend
(
lib
,
test_cases
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment