jerrrrry / infinicore · Commits

Commit 04aa18f6
Authored Feb 24, 2025 by xgqdut2016
Parent: ca2f34cf

issue/66: modified format

Showing 6 changed files with 213 additions and 285 deletions (+213 −285):
test/infiniop/causal_softmax.py     +24  −33
test/infiniop/random_sample.py      +43  −21
test/infiniop/rearrange.py          +17  −22
test/infiniop/rms_norm.py           +34  −29
test/infiniop/rotary_embedding.py   +23  −24
test/infiniop/swiglu.py             +72  −156
test/infiniop/causal_softmax.py (view file @ 04aa18f6)
@@ -23,22 +23,20 @@ from libinfiniop import (
 # These are not meant to be imported from other modules
 _TEST_CASES = [
     # x_shape, x_stride
     ((32, 512), None),
     ((32, 512), (1024, 1)),
     ((32, 5, 5), None),
     ((32, 20, 512), None),
     ((32, 20, 512), (20480, 512, 1)),  # Ascend does not yet support non-contiguous layouts
-    ((32, 20, 4, 512), None),
+    ((32, 20, 4, 512), (81920, 2048, 512, 1)),
 ]
 # Data types used for testing
-_TENSOR_DTYPES = [torch.float16, torch.float32]
+_TENSOR_DTYPES = [torch.float16]
 # Tolerance map for different data types
 _TOLERANCE_MAP = {
-    torch.float16: {'atol': 0, 'rtol': 1e-2},
-    torch.float32: {'atol': 0, 'rtol': 1e-3},
+    torch.float16: {"atol": 0, "rtol": 1e-2},
 }
 DEBUG = False
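Across all six files the tests read their tolerances through get_tolerance(_TOLERANCE_MAP, dtype). The helper lives in libinfiniop and is not part of this diff; a minimal sketch of the lookup it presumably performs (the fallback behaviour here is an assumption of this sketch, not the library's actual code):

import torch

def get_tolerance(tolerance_map, dtype, default_atol=0, default_rtol=1e-3):
    # Look up the per-dtype tolerance entry; fall back to defaults when the
    # dtype is not listed (the defaults are an assumption of this sketch).
    tol = tolerance_map.get(dtype, {"atol": default_atol, "rtol": default_rtol})
    return tol["atol"], tol["rtol"]

atol, rtol = get_tolerance({torch.float16: {"atol": 0, "rtol": 1e-2}}, torch.float16)
assert (atol, rtol) == (0, 1e-2)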
@@ -46,6 +44,7 @@ PROFILE = False
 NUM_PRERUN = 10
 NUM_ITERATIONS = 1000

 class CausalSoftmaxDescriptor(Structure):
     _fields_ = [("device", c_int32)]

@@ -61,39 +60,29 @@ def causal_softmax(x):
     return torch.nn.functional.softmax(masked, dim=-1).to(type)

 def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16):
     print(
         f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{dtype}"
     )
     x = torch.rand(x_shape, dtype=dtype).to(torch_device)
     ans = causal_softmax(x)
     x = rearrange_if_needed(x, x_stride)
     x_tensor = to_tensor(x, lib)
     descriptor = infiniopCausalSoftmaxDescriptor_t()
     check_error(
         lib.infiniopCreateCausalSoftmaxDescriptor(
             handle,
             ctypes.byref(descriptor),
             x_tensor.descriptor,
         )
     )
     # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
     x_tensor.descriptor.contents.invalidate()
     workspace_size = c_uint64(0)
     check_error(
         lib.infiniopGetCausalSoftmaxWorkspaceSize(

@@ -101,6 +90,7 @@ def test(
         )
     )
     workspace = create_workspace(workspace_size.value, x.device)

     def lib_causal_softmax():
         check_error(
             lib.infiniopCausalSoftmax(

@@ -111,8 +101,9 @@ def test(
             None,
         )
     )
     lib_causal_softmax()

     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
     if DEBUG:
         debug(x, ans, atol=atol, rtol=rtol)
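Only the last line of the reference causal_softmax is visible in these hunks: a softmax over pre-masked scores, cast back to the original dtype (the file's own code writes `.to(type)`). A minimal sketch of the full reference under the usual causal-attention convention, masking above the diagonal aligned to the bottom-right corner; the mask construction is an assumption based on the visible masked/softmax line:

import torch

def causal_softmax(x):
    # Softmax over the last dim with a causal mask: row i of the (rows, cols)
    # score matrix may only attend to the first i + (cols - rows) + 1 entries.
    dtype = x.dtype
    rows, cols = x.shape[-2], x.shape[-1]
    mask = torch.ones(rows, cols, dtype=torch.bool).tril(diagonal=cols - rows)
    masked = x.float().masked_fill(~mask, float("-inf"))
    return torch.nn.functional.softmax(masked, dim=-1).to(dtype)

out = causal_softmax(torch.rand(32, 20, 512))
assert torch.allclose(out.sum(dim=-1), torch.ones(32, 20))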
@@ -128,24 +119,23 @@ def test(
     check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))

 if __name__ == "__main__":
     args = get_args()
     lib = open_lib()
     lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32
     lib.infiniopCreateCausalSoftmaxDescriptor.argtypes = [
         infiniopHandle_t,
         POINTER(infiniopCausalSoftmaxDescriptor_t),
         infiniopTensorDescriptor_t,
     ]
     lib.infiniopGetCausalSoftmaxWorkspaceSize.restype = c_int32
     lib.infiniopGetCausalSoftmaxWorkspaceSize.argtypes = [
         infiniopCausalSoftmaxDescriptor_t,
         POINTER(c_uint64),
     ]
     lib.infiniopCausalSoftmax.restype = c_int32
     lib.infiniopCausalSoftmax.argtypes = [
         infiniopCausalSoftmaxDescriptor_t,

@@ -154,18 +144,19 @@ if __name__ == "__main__":
         c_void_p,
         c_void_p,
     ]
     lib.infiniopDestroyCausalSoftmaxDescriptor.restype = c_int32
     lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [
         infiniopCausalSoftmaxDescriptor_t,
     ]

     # Configure testing options
     DEBUG = args.debug
     PROFILE = args.profile
     NUM_PRERUN = args.num_prerun
     NUM_ITERATIONS = args.num_iterations

     for device in get_test_devices(args):
         test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

     print("\033[92mTest passed!\033[0m")
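Every __main__ block pins down the C ABI the same way: set restype and argtypes on each ctypes symbol before first use, so argument marshalling is checked instead of everything silently defaulting to int. A self-contained illustration of the pattern against libc (libc is only a stand-in here; the real tests load the infiniop shared library via open_lib()):

import ctypes
import ctypes.util

# Stand-in library; the tests bind infiniop functions the same way.
libc = ctypes.CDLL(ctypes.util.find_library("c"))

# Declaring the signature up front makes ctypes validate the call and
# convert the return value, rather than assuming a C int return type.
libc.strlen.restype = ctypes.c_size_t
libc.strlen.argtypes = [ctypes.c_char_p]

assert libc.strlen(b"infiniop") == 8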
test/infiniop/random_sample.py (view file @ 04aa18f6)
@@ -12,33 +12,40 @@ from libinfiniop import (
     create_workspace,
     test_operator,
     get_args,
-    debug,
+    debug_all,
     get_tolerance,
     profile_operation,
+    synchronize_device,
 )

 # ==============================================================================
 # Configuration (Internal Use Only)
 # ==============================================================================
 # These are not meant to be imported from other modules
 _TEST_CASES = [
     # voc, random_val, topp, topk, temperature
     (512, 0.8, 0.8, 3, 0.5),
     (4096, 0.05, 0.9, 5, 1.0),
     (16384, 0.15, 0.85, 10, 2.0),
     (512, 0.08, 0, 3, 0.5),
     (4096, 0.5, 0.9, 1, 1.0),
     (16384, 0.15, 0, 1, 2.0),
     (16384, 0.15, 0, 1, 2.0),
     (32000, 0.08, 0.8, 50, 1.0),
     (32000, 0.08, 1.0, 25, 1.0),
     # (119696, 0.01, 1.0, 100, 1.0),
 ]
 # Data types used for testing
-_TENSOR_DTYPES = [torch.float16, torch.float32]
+_TENSOR_DTYPES = [torch.float16]
+
+_TOLERANCE_MAP = {
+    torch.float16: {"atol": 0, "rtol": 0},
+}
+
+DEBUG = False
 PROFILE = False
 NUM_PRERUN = 10
 NUM_ITERATIONS = 1000
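The new _TOLERANCE_MAP pins float16 to atol 0 and rtol 0, which turns any allclose-style comparison into exact equality; for a sampling operator the selected token either matches the reference or it does not. A quick check of that arithmetic (with both tolerances at zero, torch.allclose behaves like exact equality for finite values):

import torch

a = torch.tensor([1.0, 2.0], dtype=torch.float16)
b = a.clone()
c = a + 1e-2

# |a - b| <= atol + rtol * |b| degenerates to |a - b| <= 0 when both are zero.
assert torch.allclose(a, b, atol=0, rtol=0)
assert not torch.allclose(a, c, atol=0, rtol=0)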
@@ -113,6 +120,7 @@ def test(
     x_dtype=torch.float16,
 ):
     print(f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}")
     data = torch.arange(voc).float() * 0.0001
     _perm = torch.randperm(voc)
     data = data[_perm].to(x_dtype).to(torch_device)

@@ -122,9 +130,11 @@ def test(
         )
     else:
         ans = random_sample_0(data)
     indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
-    x_tensor = to_tensor(data, lib)
-    indices_tensor = to_tensor(indices, lib)
+    x_tensor, indices_tensor = [to_tensor(tensor, lib) for tensor in [data, indices]]
     indices_tensor.descriptor.contents.dt = U64  # treat int64 as uint64
     descriptor = infiniopRandomSampleDescriptor_t()

@@ -148,7 +158,7 @@ def test(
         )
     )
     workspace = create_workspace(workspace_size.value, torch_device)

     def lib_random_sample():
         check_error(
             lib.infiniopRandomSample(

@@ -164,11 +174,21 @@ def test(
             None,
         )
     )
     if torch_device == "npu":
-        torch.npu.synchronize()
+        synchronize_device(torch_device)

+    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
+    if DEBUG:
+        debug_all(
+            (indices[0].type(ans.dtype), data[indices[0]]),
+            (ans, data[ans]),
+            "or",
+            atol=atol,
+            rtol=rtol,
+        )
     assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]

     # Profiling workflow
     if PROFILE:
         # fmt: off

@@ -184,23 +204,23 @@ def test(
     check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))

 if __name__ == "__main__":
     args = get_args()
     lib = open_lib()
     lib.infiniopCreateRandomSampleDescriptor.restype = c_int32
     lib.infiniopCreateRandomSampleDescriptor.argtypes = [
         infiniopHandle_t,
         POINTER(infiniopRandomSampleDescriptor_t),
         infiniopTensorDescriptor_t,
     ]
     lib.infiniopGetRandomSampleWorkspaceSize.restype = c_int32
     lib.infiniopGetRandomSampleWorkspaceSize.argtypes = [
         infiniopRandomSampleDescriptor_t,
         POINTER(c_uint64),
     ]
     lib.infiniopRandomSample.restype = c_int32
     lib.infiniopRandomSample.argtypes = [
         infiniopRandomSampleDescriptor_t,

@@ -214,11 +234,13 @@ if __name__ == "__main__":
         c_float,
         c_void_p,
     ]
     lib.infiniopDestroyRandomSampleDescriptor.restype = c_int32
     lib.infiniopDestroyRandomSampleDescriptor.argtypes = [
         infiniopRandomSampleDescriptor_t,
     ]

+    DEBUG = args.debug
     PROFILE = args.profile
     NUM_PRERUN = args.num_prerun
     NUM_ITERATIONS = args.num_iterations
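The case tuples (voc, random_val, topp, topk, temperature) describe a temperature-scaled top-k/top-p draw; the reference implementations (random_sample and random_sample_0) sit outside the visible hunks. A minimal sketch of what such a reference typically computes — every name and detail below is an assumption for illustration, not this file's actual code:

import torch

def sample_reference(logits, random_val, topp, topk, temperature):
    # Temperature-scale, keep the top-k tokens, truncate the tail at top-p,
    # renormalize, then invert the CDF at random_val for a deterministic pick.
    probs = torch.softmax(logits.float() / temperature, dim=-1)
    sorted_probs, order = probs.sort(descending=True)
    sorted_probs = sorted_probs[:topk]
    cum = sorted_probs.cumsum(dim=0)
    keep = int((cum < topp).sum()) + 1 if topp > 0 else topk
    kept = sorted_probs[:keep]
    cdf = kept.cumsum(dim=0) / kept.sum()
    idx = int(torch.searchsorted(cdf, min(random_val, float(cdf[-1]) - 1e-6)))
    return int(order[idx])

torch.manual_seed(0)
print(sample_reference(torch.rand(512), 0.8, 0.8, 3, 0.5))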
test/infiniop/rearrange.py (view file @ 04aa18f6)
@@ -23,13 +23,13 @@ from libinfiniop import (
 # These are not meant to be imported from other modules
 _TEST_CASES = [
     # ((src_shape, src_stride), (dst_shape, dst_stride))
     (((2, 4, 32), None), ((2, 4, 32), (256, 64, 1))),
     (((32, 6, 64), (64, 2560, 1)), ((32, 6, 64), None)),
     (((4, 6, 64), (64, 2560, 1)), ((4, 6, 64), (131072, 64, 1))),
     (((1, 32, 64), (2048, 64, 1)), ((1, 32, 64), (2048, 64, 1))),
     (((32, 1, 64), (64, 2560, 1)), ((32, 1, 64), (64, 64, 1))),
     (((4, 1, 64), (64, 2560, 1)), ((4, 1, 64), (64, 11264, 1))),
     (((64,), (1,)), ((64,), (1,))),
 ]
 # Data types used for testing

@@ -37,8 +37,8 @@ _TENSOR_DTYPES = [torch.float16, torch.float32]
 # Tolerance map for different data types
 _TOLERANCE_MAP = {
-    torch.float16: {"atol": 0, "rtol": 1e-3},
-    torch.float32: {"atol": 0, "rtol": 1e-3},
+    torch.float16: {"atol": 0, "rtol": 0},
+    torch.float32: {"atol": 0, "rtol": 0},
 }
 DEBUG = False

@@ -47,7 +47,6 @@ NUM_PRERUN = 10
 NUM_ITERATIONS = 1000

 class RerrangeDescriptor(Structure):
     _fields_ = [("device", c_int32)]

@@ -68,16 +67,16 @@ def test(
     print(
         f"Testing Rerrange on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} y_shape:{y_shape} y_stride:{y_stride} x_dtype:{x_dtype}"
     )
     x = torch.rand(x_shape, dtype=x_dtype).to(torch_device)
     y = torch.zeros(y_shape, dtype=x_dtype).to(torch_device)
     x, y = [
         rearrange_if_needed(tensor, stride)
         for tensor, stride in zip([x, y], [x_stride, y_stride])
     ]
     x_tensor, y_tensor = [to_tensor(tensor, lib) for tensor in [x, y]]
     descriptor = infiniopRearrangeDescriptor_t()
     check_error(
         lib.infiniopCreateRearrangeDescriptor(

@@ -91,15 +90,11 @@ def test(
     def lib_rearrange():
         check_error(
             lib.infiniopRearrange(
                 descriptor, y_tensor.data, x_tensor.data, None)
         )

     lib_rearrange()

     # Validate results
     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
     if DEBUG:

@@ -116,8 +111,6 @@ def test(
     check_error(lib.infiniopDestroyRearrangeDescriptor(descriptor))

 if __name__ == "__main__":
     args = get_args()
     lib = open_lib()

@@ -129,6 +122,7 @@ if __name__ == "__main__":
         infiniopTensorDescriptor_t,
         infiniopTensorDescriptor_t,
     ]
     lib.infiniopRearrange.restype = c_int32
     lib.infiniopRearrange.argtypes = [
         infiniopRearrangeDescriptor_t,

@@ -136,9 +130,10 @@ if __name__ == "__main__":
         c_void_p,
         c_void_p,
     ]
     lib.infiniopDestroyRearrangeDescriptor.restype = c_int32
     lib.infiniopDestroyRearrangeDescriptor.argtypes = [infiniopRearrangeDescriptor_t]

     # Configure testing options
     DEBUG = args.debug
     PROFILE = args.profile
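rearrange_if_needed and to_tensor come from libinfiniop and are not part of this diff. The tests rely on the former to materialize a tensor with an explicit, possibly padded and non-contiguous, stride layout before handing it to the library. A plausible sketch using torch.as_strided (the helper's real implementation may differ):

import torch

def rearrange_if_needed(tensor, stride):
    # When an explicit stride layout is requested, allocate a buffer large
    # enough for that layout and copy the data into a strided view of it.
    if stride is None:
        return tensor
    extent = 1 + sum((s - 1) * st for s, st in zip(tensor.shape, stride))
    buf = torch.zeros(extent, dtype=tensor.dtype, device=tensor.device)
    out = buf.as_strided(tensor.shape, stride)
    out.copy_(tensor)
    return out

x = rearrange_if_needed(torch.rand(2, 4, 32), (256, 64, 1))
assert x.stride() == (256, 64, 1) and x.shape == (2, 4, 32)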
test/infiniop/rms_norm.py (view file @ 04aa18f6)
@@ -23,21 +23,20 @@ from libinfiniop import (
 # Configuration (Internal Use Only)
 # ==============================================================================
 # These are not meant to be imported from other modules
 _TEST_CASES = [
     # y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
     ((16, 2048), (16, 2048), (2048,), None, None, torch.float32),
     ((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
     ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32),
     ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
 ]
 # x types used for testing
-_TENSOR_DTYPES = [torch.float16, torch.float32]
+_TENSOR_DTYPES = [torch.float16]
 # Tolerance map for different data types
 _TOLERANCE_MAP = {
-    torch.float16: {"atol": 0, "rtol": 1e-2},
-    torch.float32: {"atol": 0, "rtol": 1e-3},
+    torch.float16: {"atol": 1e-3, "rtol": 1e-3},
 }
 DEBUG = False

@@ -45,12 +44,14 @@ PROFILE = False
 NUM_PRERUN = 10
 NUM_ITERATIONS = 1000

 class RMSNormDescriptor(Structure):
     _fields_ = [("device", c_int32)]

 infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)

 def rms_norm(x, w, eps):
     input_dtype = x.dtype
     hidden_states = x.to(torch.float32)
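The reference rms_norm is only partially visible here: it records the input dtype and upcasts to float32 before normalizing. A minimal sketch of the usual RMSNorm computation the rest of the function presumably performs, y = w * x / sqrt(mean(x^2) + eps); the body beyond the two visible lines is an assumption:

import torch

def rms_norm(x, w, eps):
    # Normalize in float32 for accuracy, then cast back to the input dtype.
    input_dtype = x.dtype
    hidden_states = x.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return (w * hidden_states).to(input_dtype)

x = torch.rand(16, 2048, dtype=torch.float16)
w = torch.ones(2048, dtype=torch.float16)
y = rms_norm(x, w, eps=1e-5)
assert y.shape == x.shape and y.dtype == torch.float16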
@@ -60,18 +61,21 @@ def rms_norm(x, w, eps):
 def test(
     lib,
     handle,
     torch_device,
     y_shape,
     x_shape,
     w_shape,
     y_stride,
     x_stride,
     dtype=torch.float16,
-    w_dtype=torch.float16):
+    w_dtype=torch.float16,
+):
     print(
         f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
         f" dtype:{dtype} w_dtype:{w_dtype}"
     )
     y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
     x = torch.rand(x_shape, dtype=dtype).to(torch_device)

@@ -80,18 +84,23 @@ def test(
     eps = 1e-5
     ans = rms_norm(x, w, eps)
-    x = rearrange_if_needed(x, x_stride)
-    y = rearrange_if_needed(y, y_stride)
+    x, y = [
+        rearrange_if_needed(tensor, stride)
+        for tensor, stride in zip([x, y], [x_stride, y_stride])
+    ]
     x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]
     descriptor = infiniopRMSNormDescriptor_t()
-    w_dataType = 0 if w_dtype == torch.float16 else 1
     check_error(
         lib.infiniopCreateRMSNormDescriptor(
-            handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor,
-            w_tensor.descriptor, eps
+            handle,
+            ctypes.byref(descriptor),
+            y_tensor.descriptor,
+            x_tensor.descriptor,
+            w_tensor.descriptor,
+            eps,
         )
     )

@@ -101,11 +110,10 @@ def test(
     workspace_size = c_uint64(0)
     check_error(
         lib.infiniopGetRMSNormWorkspaceSize(
-            descriptor, ctypes.byref(workspace_size))
+            descriptor, ctypes.byref(workspace_size)
+        )
     )
     workspace = create_workspace(workspace_size.value, y.device)

     def lib_rms_norm():
         check_error(
             lib.infiniopRMSNorm(

@@ -134,12 +142,10 @@ def test(
     check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))

 if __name__ == "__main__":
     args = get_args()
     lib = open_lib()
     lib.infiniopCreateRMSNormDescriptor.restype = c_int32
     lib.infiniopCreateRMSNormDescriptor.argtypes = [
         infiniopHandle_t,

@@ -166,6 +172,7 @@ if __name__ == "__main__":
         c_void_p,
         c_void_p,
     ]
     lib.infiniopDestroyRMSNormDescriptor.restype = c_int32
     lib.infiniopDestroyRMSNormDescriptor.argtypes = [
         infiniopRMSNormDescriptor_t,

@@ -182,5 +189,3 @@ if __name__ == "__main__":
         test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

     print("\033[92mTest passed!\033[0m")
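All of these operators share the same workspace protocol: query the required byte count through a uint64 out-parameter, allocate once with create_workspace, and pass the buffer to every launch. A self-contained sketch of the out-parameter step in plain ctypes (fake_get_workspace_size is a stand-in; the real call is lib.infiniopGetRMSNormWorkspaceSize(descriptor, byref(size))):

import ctypes

# Emulate a C "get workspace size" API: int32_t f(uint64_t *size_out).
@ctypes.CFUNCTYPE(ctypes.c_int32, ctypes.POINTER(ctypes.c_uint64))
def fake_get_workspace_size(size_out):
    size_out[0] = 4096  # pretend the kernel needs 4 KiB of scratch space
    return 0            # status code 0 means success

workspace_size = ctypes.c_uint64(0)
status = fake_get_workspace_size(ctypes.byref(workspace_size))
assert status == 0 and workspace_size.value == 4096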
test/infiniop/rotary_embedding.py (view file @ 04aa18f6)
@@ -15,6 +15,7 @@ from libinfiniop import (
     debug,
     get_tolerance,
     profile_operation,
+    synchronize_device,
 )

 # ==============================================================================

@@ -23,22 +24,21 @@ from libinfiniop import (
 # These are not meant to be imported from other modules
 _TEST_CASES = [
     # (t_shape, t_strides)
     ((1, 32, 128), None),
     ((1, 32, 64), None),
     # Ascend does not yet handle this case: a last dimension <= 32 is
     # problematic, possibly due to the internal implementation of its core
     # GatherMask interface; 48, 64, and 128 are all currently supported
     ((4, 1, 32), None),
     ((1, 32, 128), None),
     ((3, 32, 128), (8000, 200, 1)),
 ]
 # Data types used for testing
-_TENSOR_DTYPES = [torch.float16, torch.float32]
+_TENSOR_DTYPES = [torch.float16]
 # Tolerance map for different data types
 _TOLERANCE_MAP = {
-    torch.float16: {"atol": 0, "rtol": 1e-2},
-    torch.float32: {"atol": 0, "rtol": 1e-3},
+    torch.float16: {"atol": 1e-4, "rtol": 1e-2},
 }
 DEBUG = False

@@ -47,7 +47,6 @@ NUM_PRERUN = 10
 NUM_ITERATIONS = 1000

 class RoPEDescriptor(Structure):
     _fields_ = [("device", c_int32)]

@@ -96,14 +95,7 @@ def sin_cos_table(max_seq_len, dim, torch_device, theta):
     return torch.sin(angles), torch.cos(angles)

 def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
     print(
         f"Testing Rotary Positional Embedding on {torch_device} with shape:{shape} strides:{strides} and dtype:{dtype}"
     )

@@ -126,14 +118,15 @@ def test(
     # 2x table length for test
     sin_table, cos_table = sin_cos_table(t.shape[0] * 2, t.shape[2], t.device, theta)
-    t_tensor, sin_table_tensor, cos_table_tensor = [to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table]]
+    t_tensor, sin_table_tensor, cos_table_tensor = [
+        to_tensor(tensor, lib) for tensor in [t, sin_table, cos_table]
+    ]
     pos_tensor = to_tensor(pos[: t.shape[0]], lib)
     pos_tensor.descriptor.contents.dtype = InfiniDtype.U64
     if torch_device == "npu":
-        torch.npu.synchronize()
+        synchronize_device(torch_device)
     check_error(
         lib.infiniopCreateRoPEDescriptor(

@@ -171,11 +164,12 @@ def test(
     )
     lib_rope()

     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
     if DEBUG:
         debug(t, ans, atol=atol, rtol=rtol)
     assert torch.allclose(t, ans, atol=atol, rtol=rtol)

     if PROFILE:
         profile_operation(
             "PyTorch",

@@ -194,6 +188,7 @@ def test(
 if __name__ == "__main__":
     args = get_args()
     lib = open_lib()
     lib.infiniopCreateRoPEDescriptor.restype = c_int32
     lib.infiniopCreateRoPEDescriptor.argtypes = [
         infiniopHandle_t,

@@ -203,11 +198,13 @@ if __name__ == "__main__":
         infiniopTensorDescriptor_t,
         infiniopTensorDescriptor_t,
     ]
     lib.infiniopGetRoPEWorkspaceSize.restype = c_int32
     lib.infiniopGetRoPEWorkspaceSize.argtypes = [
         infiniopRoPEDescriptor_t,
         POINTER(c_uint64),
     ]
     lib.infiniopRoPE.restype = c_int32
     lib.infiniopRoPE.argtypes = [
         infiniopRoPEDescriptor_t,

@@ -219,10 +216,12 @@ if __name__ == "__main__":
         c_void_p,
         c_void_p,
     ]
     lib.infiniopDestroyRoPEDescriptor.restype = c_int32
     lib.infiniopDestroyRoPEDescriptor.argtypes = [
         infiniopRoPEDescriptor_t,
     ]

     # Configure testing options
     DEBUG = args.debug
     PROFILE = args.profile
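sin_cos_table(max_seq_len, dim, torch_device, theta) returns torch.sin(angles) and torch.cos(angles); the construction of angles sits above the visible hunk. A minimal sketch using the standard RoPE frequency schedule, angle[p, i] = p * theta^(-2i/dim) (the exact layout in this file is an assumption):

import torch

def sin_cos_table(max_seq_len, dim, torch_device, theta):
    # One frequency per pair of channels; each position p is rotated by p * freq.
    freqs = theta ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    pos = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(pos, freqs).to(torch_device)
    return torch.sin(angles), torch.cos(angles)

sin_table, cos_table = sin_cos_table(64, 128, "cpu", 10000.0)
assert sin_table.shape == (64, 64)  # (max_seq_len, dim // 2)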
test/infiniop/swiglu.py (view file @ 04aa18f6)
@@ -16,29 +16,64 @@ from libinfiniop import (
     get_tolerance,
     profile_operation,
 )
+from enum import Enum, auto

 # ==============================================================================
 # Configuration (Internal Use Only)
 # ==============================================================================
 # These are not meant to be imported from other modules
 _TEST_CASES = [
-    # shape, a_stride, b_stride, c_stride
-    ((13, 4), None, None, None),
-    ((13, 4), (10, 1), (10, 1), (10, 1)),
-    ((13, 4, 4), None, None, None),
-    ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
-    ((16, 5632), None, None, None),
-    ((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
-    ((4, 4, 5632), None, None, None),
-    ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
+    # shape, a_stride, b_stride, c_stride, inplace
+    ((13, 4), None, None, None, Inplace.OUT_OF_PLACE),
+    ((13, 4), None, None, None, Inplace.INPLACE_A),
+    ((13, 4), None, None, None, Inplace.INPLACE_B),
+    ((13, 4), (10, 1), (10, 1), (10, 1), Inplace.OUT_OF_PLACE),
+    ((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_A),
+    ((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_B),
+    ((13, 4, 4), None, None, None, Inplace.OUT_OF_PLACE),
+    ((13, 4, 4), None, None, None, Inplace.INPLACE_A),
+    ((13, 4, 4), None, None, None, Inplace.INPLACE_B),
+    ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.OUT_OF_PLACE),
+    ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_A),
+    ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_B),
+    ((16, 5632), None, None, None, Inplace.OUT_OF_PLACE),
+    ((16, 5632), None, None, None, Inplace.INPLACE_A),
+    ((16, 5632), None, None, None, Inplace.INPLACE_B),
+    ((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.OUT_OF_PLACE),
+    ((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_A),
+    ((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_B),
+    ((4, 4, 5632), None, None, None, Inplace.OUT_OF_PLACE),
+    ((4, 4, 5632), None, None, None, Inplace.INPLACE_A),
+    ((4, 4, 5632), None, None, None, Inplace.INPLACE_B),
+    (
+        (4, 4, 5632),
+        (45056, 5632, 1),
+        (45056, 5632, 1),
+        (45056, 5632, 1),
+        Inplace.OUT_OF_PLACE,
+    ),
+    (
+        (4, 4, 5632),
+        (45056, 5632, 1),
+        (45056, 5632, 1),
+        (45056, 5632, 1),
+        Inplace.INPLACE_A,
+    ),
+    (
+        (4, 4, 5632),
+        (45056, 5632, 1),
+        (45056, 5632, 1),
+        (45056, 5632, 1),
+        Inplace.INPLACE_B,
+    ),
 ]
 # Data types used for testing
-_TENSOR_DTYPES = [torch.float16, torch.float32]
+_TENSOR_DTYPES = [torch.float16]
 # Tolerance map for different data types
 _TOLERANCE_MAP = {
-    torch.float16: {'atol': 0, 'rtol': 1e-2},
-    torch.float32: {'atol': 0, 'rtol': 1e-3},
+    torch.float16: {"atol": 1e-4, "rtol": 1e-2},
 }
 DEBUG = False

@@ -46,6 +81,13 @@ PROFILE = False
 NUM_PRERUN = 10
 NUM_ITERATIONS = 1000

+class Inplace(Enum):
+    OUT_OF_PLACE = auto()
+    INPLACE_A = auto()
+    INPLACE_B = auto()
+
 class SwiGLUDescriptor(Structure):
     _fields_ = [("device", c_int32)]
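The new Inplace enum folds the three formerly separate test paths (out-of-place, output aliasing a, output aliasing b) into one parametrized test: the case tuple now selects which buffer receives the result. A toy illustration of the aliasing logic, independent of the library:

import torch
from enum import Enum, auto

class Inplace(Enum):
    OUT_OF_PLACE = auto()
    INPLACE_A = auto()
    INPLACE_B = auto()

def pick_output(a, b, inplace):
    # Mirrors the test's selection: a fresh buffer, or alias one input.
    if inplace == Inplace.OUT_OF_PLACE:
        return torch.empty_like(a)
    return a if inplace == Inplace.INPLACE_A else b

a, b = torch.rand(13, 4), torch.rand(13, 4)
assert pick_output(a, b, Inplace.INPLACE_A) is a
assert pick_output(a, b, Inplace.INPLACE_B) is b
assert pick_output(a, b, Inplace.OUT_OF_PLACE) is not a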
@@ -54,11 +96,10 @@ infiniopSwiGLUDescriptor_t = POINTER(SwiGLUDescriptor)
 def swiglu(a, b):
     return a * b / (1 + torch.exp(-b.float()).to(b.dtype))

-def test_out_of_place(
+def test(
     lib,
     handle,
     torch_device,

@@ -66,15 +107,21 @@ def test_out_of_place(
     a_stride=None,
     b_stride=None,
     c_stride=None,
+    inplace=Inplace.OUT_OF_PLACE,
     dtype=torch.float16,
     sync=None,
 ):
     print(
         f"Testing SwiGLU on {torch_device} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} dtype:{dtype}"
     )
     a = torch.rand(shape, dtype=dtype).to(torch_device)
     b = torch.rand(shape, dtype=dtype).to(torch_device)
-    c = torch.rand(shape, dtype=dtype).to(torch_device)
+    c = (
+        torch.rand(c_shape, dtype=tensor_dtype).to(torch_device)
+        if inplace == Inplace.OUT_OF_PLACE
+        else (a if inplace == Inplace.INPLACE_A else b)
+    )

     ans = swiglu(a, b)

@@ -82,9 +129,12 @@ def test(
         rearrange_if_needed(tensor, stride)
         for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
     ]
-    a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]
+    a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
+    c_tensor = (
+        to_tensor(c, lib)
+        if inplace == Inplace.OUT_OF_PLACE
+        else (a_tensor if inplace == Inplace.INPLACE_A else b_tensor)
+    )

     if sync is not None:
         sync()

@@ -106,13 +156,10 @@ def test(
     def lib_swiglu():
         check_error(
             lib.infiniopSwiGLU(
                 descriptor,
                 c_tensor.data, a_tensor.data, b_tensor.data, None
             )
         )

     lib_swiglu()

     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)

@@ -130,139 +177,7 @@ def test(
     check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))

-def test_in_place1(
-    lib,
-    handle,
-    torch_device,
-    shape,
-    a_stride=None,
-    b_stride=None,
-    dtype=torch.float16,
-    sync=None,
-):
-    a = torch.rand(shape, dtype=dtype).to(torch_device)
-    b = torch.rand(shape, dtype=dtype).to(torch_device)
-    ans = swiglu(a, b)
-    if sync is not None:
-        sync()
-    a, b = [
-        rearrange_if_needed(tensor, stride)
-        for tensor, stride in zip([a, b], [a_stride, b_stride])
-    ]
-    a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
-    descriptor = infiniopSwiGLUDescriptor_t()
-    check_error(
-        lib.infiniopCreateSwiGLUDescriptor(
-            handle,
-            ctypes.byref(descriptor),
-            a_tensor.descriptor,
-            a_tensor.descriptor,
-            b_tensor.descriptor,
-        )
-    )
-    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
-    for tensor in [a_tensor, b_tensor]:
-        tensor.descriptor.contents.invalidate()
-
-    def lib_swiglu():
-        check_error(
-            lib.infiniopSwiGLU(
-                descriptor, a_tensor.data, a_tensor.data, b_tensor.data, None
-            )
-        )
-
-    lib_swiglu()
-
-    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
-    if DEBUG:
-        debug(a, ans, atol=atol, rtol=rtol)
-    assert torch.allclose(a, ans, atol=atol, rtol=rtol)
-    print("in-place1 Test passed!")
-
-    # Profiling workflow
-    if PROFILE:
-        # fmt: off
-        profile_operation("PyTorch", lambda: swiglu(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
-        profile_operation("    lib", lambda: lib_swiglu(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
-        # fmt: on
-    check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
-
-def test_in_place2(
-    lib,
-    handle,
-    torch_device,
-    shape,
-    a_stride=None,
-    b_stride=None,
-    dtype=torch.float16,
-    sync=None,
-):
-    a = torch.rand(shape, dtype=dtype).to(torch_device)
-    b = torch.rand(shape, dtype=dtype).to(torch_device)
-    ans = swiglu(a, b)
-    if sync is not None:
-        sync()
-    a, b = [
-        rearrange_if_needed(tensor, stride)
-        for tensor, stride in zip([a, b], [a_stride, b_stride])
-    ]
-    a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
-    descriptor = infiniopSwiGLUDescriptor_t()
-    check_error(
-        lib.infiniopCreateSwiGLUDescriptor(
-            handle,
-            ctypes.byref(descriptor),
-            b_tensor.descriptor,
-            a_tensor.descriptor,
-            b_tensor.descriptor,
-        )
-    )
-    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
-    for tensor in [a_tensor, b_tensor]:
-        tensor.descriptor.contents.invalidate()
-
-    def lib_swiglu():
-        check_error(
-            lib.infiniopSwiGLU(
-                descriptor, b_tensor.data, a_tensor.data, b_tensor.data, None
-            )
-        )
-
-    lib_swiglu()
-
-    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
-    if DEBUG:
-        debug(b, ans, atol=atol, rtol=rtol)
-    assert torch.allclose(b, ans, atol=atol, rtol=rtol)
-    print("in-place2 Test passed!")
-
-    # Profiling workflow
-    if PROFILE:
-        # fmt: off
-        profile_operation("PyTorch", lambda: swiglu(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
-        profile_operation("    lib", lambda: lib_swiglu(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
-        # fmt: on
-    check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor))
-
-def test(
-    lib, handle, torch_device, shape, a_stride, b_stride, c_stride, dtype, sync=None
-):
-    test_out_of_place(
-        lib, handle, torch_device, shape, a_stride, b_stride, c_stride, dtype, sync
-    )
-    test_in_place1(lib, handle, torch_device, shape, a_stride, b_stride, dtype, sync)
-    test_in_place2(lib, handle, torch_device, shape, a_stride, b_stride, dtype, sync)

 if __name__ == "__main__":
     args = get_args()
     lib = open_lib()

@@ -288,12 +203,13 @@ if __name__ == "__main__":
     lib.infiniopDestroySwiGLUDescriptor.argtypes = [
         infiniopSwiGLUDescriptor_t,
     ]

     # Configure testing options
     DEBUG = args.debug
     PROFILE = args.profile
     NUM_PRERUN = args.num_prerun
     NUM_ITERATIONS = args.num_iterations

     for device in get_test_devices(args):
         test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
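The reference swiglu(a, b) = a * b / (1 + exp(-b)) is SwiGLU written out by hand: b / (1 + exp(-b)) is b * sigmoid(b), i.e. SiLU, so the whole expression is a * silu(b). A quick numerical check of that identity:

import torch

def swiglu(a, b):
    # Reference from the test file: a * b * sigmoid(b), with exp in float32.
    return a * b / (1 + torch.exp(-b.float()).to(b.dtype))

a = torch.rand(16, 5632, dtype=torch.float16)
b = torch.rand(16, 5632, dtype=torch.float16)

expected = a * torch.nn.functional.silu(b.float()).to(b.dtype)
assert torch.allclose(swiglu(a, b), expected, atol=1e-3, rtol=1e-2)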