Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
5a52e21f
Unverified
Commit
5a52e21f
authored
Aug 12, 2022
by
Frank Lee
Committed by
GitHub
Aug 12, 2022
Browse files
[test] fixed the activation codegen test (#1447)
* [test] fixed the activation codegen test * polish code
parent
0f304236
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
7 deletions
+15
-7
colossalai/utils/common.py
colossalai/utils/common.py
+0
-1
tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
...est_fx/test_codegen/test_activation_checkpoint_codegen.py
+15
-6
No files found.
colossalai/utils/common.py
View file @
5a52e21f
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
os
from
pprint
import
pp
import
random
import
socket
from
pathlib
import
Path
...
...
tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
View file @
5a52e21f
from
operator
import
mod
import
torch
import
pytest
import
torch.multiprocessing
as
mp
from
torch.utils.checkpoint
import
checkpoint
from
torch.fx
import
GraphModule
from
colossalai.fx
import
ColoTracer
...
...
@@ -42,10 +43,9 @@ class MyModule(torch.nn.Module):
return
y1
+
y2
+
y3
+
y4
@
pytest
.
mark
.
skipif
(
not
with_codegen
,
reason
=
'torch version is lower than 1.12.0'
)
def
test_act_ckpt_codegen
():
def
_run_act_ckpt_codegen
(
rank
):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
colossalai
.
launch
(
config
=
{},
rank
=
0
,
world_size
=
1
,
host
=
'localhost'
,
port
=
free_port
(),
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
1
,
host
=
'localhost'
,
port
=
free_port
(),
backend
=
'nccl'
)
# build model and run forward
model
=
MyModule
()
...
...
@@ -90,10 +90,14 @@ def test_act_ckpt_codegen():
gpc
.
destroy
()
@
pytest
.
mark
.
skipif
(
with_codegen
,
reason
=
'torch version is equal to or higher than 1.12.0'
)
def
test_act_ckpt_python_code_torch11
():
@
pytest
.
mark
.
skipif
(
not
with_codegen
,
reason
=
'torch version is lower than 1.12.0'
)
def
test_act_ckpt_codegen
():
mp
.
spawn
(
_run_act_ckpt_codegen
,
nprocs
=
1
)
def
_run_act_ckpt_python_code_torch11
(
rank
):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
colossalai
.
launch
(
config
=
{},
rank
=
0
,
world_size
=
1
,
host
=
'localhost'
,
port
=
free_port
(),
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
1
,
host
=
'localhost'
,
port
=
free_port
(),
backend
=
'nccl'
)
# build model and run forward
model
=
MyModule
()
...
...
@@ -138,6 +142,11 @@ def test_act_ckpt_python_code_torch11():
gpc
.
destroy
()
@
pytest
.
mark
.
skipif
(
with_codegen
,
reason
=
'torch version is equal to or higher than 1.12.0'
)
def
test_act_ckpt_python_code_torch11
():
mp
.
spawn
(
_run_act_ckpt_python_code_torch11
,
nprocs
=
1
)
if
__name__
==
'__main__'
:
test_act_ckpt_codegen
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment