Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
c0811ed4
Commit
c0811ed4
authored
Feb 25, 2025
by
xgqdut2016
Browse files
issue/66: modified random_sample, swiglu, rms_norm, test
parent
08a29c28
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
50 deletions
+25
-50
test/infiniop/random_sample.py
test/infiniop/random_sample.py
+2
-6
test/infiniop/rms_norm.py
test/infiniop/rms_norm.py
+1
-0
test/infiniop/swiglu.py
test/infiniop/swiglu.py
+22
-44
No files found.
test/infiniop/random_sample.py
View file @
c0811ed4
...
...
@@ -188,13 +188,9 @@ def test(
# Profiling workflow
if
PROFILE
:
# fmt: off
if
topp
>
0
and
topk
>
1
:
profile_operation
(
"PyTorch"
,
lambda
:
random_sample
(
data
.
to
(
"cpu"
),
random_val
,
topp
,
topk
,
voc
,
temperature
,
"cpu"
profile_operation
(
"PyTorch"
,
lambda
:
random_sample
(
data
,
random_val
,
topp
,
topk
,
voc
,
temperature
,
torch_device
),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
else
:
profile_operation
(
"PyTorch"
,
lambda
:
random_sample_0
(
data
),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
profile_operation
(
" lib"
,
lambda
:
lib_random_sample
(),
torch_device
,
NUM_PRERUN
,
NUM_ITERATIONS
)
# fmt: on
check_error
(
lib
.
infiniopDestroyRandomSampleDescriptor
(
descriptor
))
...
...
test/infiniop/rms_norm.py
View file @
c0811ed4
...
...
@@ -133,6 +133,7 @@ def test(
if
DEBUG
:
debug
(
y
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
y
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
# Profiling workflow
if
PROFILE
:
# fmt: off
...
...
test/infiniop/swiglu.py
View file @
c0811ed4
...
...
@@ -22,50 +22,29 @@ from enum import Enum, auto
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_
=
[
((
13
,
4
),
None
,
None
,
None
),
((
13
,
4
),
(
10
,
1
),
(
10
,
1
),
(
10
,
1
)),
((
13
,
4
,
4
),
None
,
None
,
None
),
((
13
,
4
,
4
),
(
20
,
4
,
1
),
(
20
,
4
,
1
),
(
20
,
4
,
1
)),
((
16
,
5632
),
None
,
None
,
None
),
((
16
,
5632
),
(
13312
,
1
),
(
13312
,
1
),
(
13312
,
1
)),
((
4
,
4
,
5632
),
None
,
None
,
None
),
((
4
,
4
,
5632
),
(
45056
,
5632
,
1
),
(
45056
,
5632
,
1
),
(
45056
,
5632
,
1
)),
]
# Inplace options applied for each test case in _TEST_CASES_
_INPLACE
=
[
"Inplace.OUT_OF_PLACE"
,
"Inplace.INPLACE_A"
,
"Inplace.INPLACE_B"
,
]
# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_
_TEST_CASES
=
[
# shape, a_stride, b_stride, c_stride, inplace
((
13
,
4
),
None
,
None
,
None
,
Inplace
.
OUT_OF_PLACE
),
((
13
,
4
),
None
,
None
,
None
,
Inplace
.
INPLACE_A
),
((
13
,
4
),
None
,
None
,
None
,
Inplace
.
INPLACE_B
),
((
13
,
4
),
(
10
,
1
),
(
10
,
1
),
(
10
,
1
),
Inplace
.
OUT_OF_PLACE
),
((
13
,
4
),
(
10
,
1
),
(
10
,
1
),
(
10
,
1
),
Inplace
.
INPLACE_A
),
((
13
,
4
),
(
10
,
1
),
(
10
,
1
),
(
10
,
1
),
Inplace
.
INPLACE_B
),
((
13
,
4
,
4
),
None
,
None
,
None
,
Inplace
.
OUT_OF_PLACE
),
((
13
,
4
,
4
),
None
,
None
,
None
,
Inplace
.
INPLACE_A
),
((
13
,
4
,
4
),
None
,
None
,
None
,
Inplace
.
INPLACE_B
),
((
13
,
4
,
4
),
(
20
,
4
,
1
),
(
20
,
4
,
1
),
(
20
,
4
,
1
),
Inplace
.
OUT_OF_PLACE
),
((
13
,
4
,
4
),
(
20
,
4
,
1
),
(
20
,
4
,
1
),
(
20
,
4
,
1
),
Inplace
.
INPLACE_A
),
((
13
,
4
,
4
),
(
20
,
4
,
1
),
(
20
,
4
,
1
),
(
20
,
4
,
1
),
Inplace
.
INPLACE_B
),
((
16
,
5632
),
None
,
None
,
None
,
Inplace
.
OUT_OF_PLACE
),
((
16
,
5632
),
None
,
None
,
None
,
Inplace
.
INPLACE_A
),
((
16
,
5632
),
None
,
None
,
None
,
Inplace
.
INPLACE_B
),
((
16
,
5632
),
(
13312
,
1
),
(
13312
,
1
),
(
13312
,
1
),
Inplace
.
OUT_OF_PLACE
),
((
16
,
5632
),
(
13312
,
1
),
(
13312
,
1
),
(
13312
,
1
),
Inplace
.
INPLACE_A
),
((
16
,
5632
),
(
13312
,
1
),
(
13312
,
1
),
(
13312
,
1
),
Inplace
.
INPLACE_B
),
((
4
,
4
,
5632
),
None
,
None
,
None
,
Inplace
.
OUT_OF_PLACE
),
((
4
,
4
,
5632
),
None
,
None
,
None
,
Inplace
.
INPLACE_A
),
((
4
,
4
,
5632
),
None
,
None
,
None
,
Inplace
.
INPLACE_B
),
(
(
4
,
4
,
5632
),
(
45056
,
5632
,
1
),
(
45056
,
5632
,
1
),
(
45056
,
5632
,
1
),
Inplace
.
OUT_OF_PLACE
,
),
(
(
4
,
4
,
5632
),
(
45056
,
5632
,
1
),
(
45056
,
5632
,
1
),
(
45056
,
5632
,
1
),
Inplace
.
INPLACE_A
,
),
(
(
4
,
4
,
5632
),
(
45056
,
5632
,
1
),
(
45056
,
5632
,
1
),
(
45056
,
5632
,
1
),
Inplace
.
INPLACE_B
,
),
test_case
+
(
inplace_item
,)
for
test_case
in
_TEST_CASES_
for
inplace_item
in
_INPLACE
]
# Data types used for testing
...
...
@@ -166,7 +145,6 @@ def test(
if
DEBUG
:
debug
(
c
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
assert
torch
.
allclose
(
c
,
ans
,
atol
=
atol
,
rtol
=
rtol
)
print
(
"out-of-place Test passed!"
)
# Profiling workflow
if
PROFILE
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment