Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
3abbaf8b
Commit
3abbaf8b
authored
Jan 09, 2023
by
oahzxl
Browse files
update codegen test
parent
74b81395
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
7 deletions
+16
-7
tests/test_autochunk/test_autochunk_codegen.py
tests/test_autochunk/test_autochunk_codegen.py
+16
-7
No files found.
tests/test_autochunk/test_autochunk_codegen.py
View file @
3abbaf8b
from
functools
import
partial
import
pytest
import
pytest
import
torch
import
torch
import
torch.fx
import
torch.fx
...
@@ -46,7 +48,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
...
@@ -46,7 +48,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
)
)
def
_test_autochunk_codegen
(
rank
):
def
_test_autochunk_codegen
(
rank
,
msa_len
,
pair_len
,
max_memory
):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
colossalai
.
launch
(
colossalai
.
launch
(
config
=
{},
config
=
{},
...
@@ -59,8 +61,6 @@ def _test_autochunk_codegen(rank):
...
@@ -59,8 +61,6 @@ def _test_autochunk_codegen(rank):
# build model and input
# build model and input
model
=
evoformer_base
().
cuda
()
model
=
evoformer_base
().
cuda
()
msa_len
=
32
pair_len
=
64
node
=
torch
.
randn
(
1
,
msa_len
,
pair_len
,
256
).
cuda
()
node
=
torch
.
randn
(
1
,
msa_len
,
pair_len
,
256
).
cuda
()
pair
=
torch
.
randn
(
1
,
pair_len
,
pair_len
,
128
).
cuda
()
pair
=
torch
.
randn
(
1
,
pair_len
,
pair_len
,
128
).
cuda
()
...
@@ -85,7 +85,7 @@ def _test_autochunk_codegen(rank):
...
@@ -85,7 +85,7 @@ def _test_autochunk_codegen(rank):
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
)
)
codegen
=
AutoChunkCodeGen
(
gm_prop
)
codegen
=
AutoChunkCodeGen
(
gm_prop
,
max_memory
=
max_memory
)
graph
.
set_codegen
(
codegen
)
graph
.
set_codegen
(
codegen
)
gm
=
ColoGraphModule
(
model
,
graph
)
gm
=
ColoGraphModule
(
model
,
graph
)
gm
.
recompile
()
gm
.
recompile
()
...
@@ -99,9 +99,18 @@ def _test_autochunk_codegen(rank):
...
@@ -99,9 +99,18 @@ def _test_autochunk_codegen(rank):
gpc
.
destroy
()
gpc
.
destroy
()
@pytest.mark.parametrize("max_memory", [None, 20, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_autochunk_codegen(msa_len, pair_len, max_memory):
    """Run the autochunk codegen check in a single spawned worker process.

    Parametrized over the MSA/pair sequence lengths and the memory budget
    passed through to ``AutoChunkCodeGen`` (``None`` means no explicit cap).
    """
    # Bind the test parameters up front; mp.spawn supplies the rank argument.
    worker = partial(
        _test_autochunk_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    mp.spawn(worker, nprocs=1)
# Allow running the check directly: rank 0, msa_len=32, pair_len=64,
# and no explicit memory cap (max_memory=None).
if __name__ == "__main__":
    _test_autochunk_codegen(0, 32, 64, None)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment