OpenDAS / ColossalAI · Commit cd063ac3 (unverified)

[fx] added activation checkpoint codegen support for torch < 1.12 (#1359)

Authored Jul 25, 2022 by Frank Lee; committed by GitHub on Jul 25, 2022
Parent: 44178041
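For context on the feature name: activation checkpointing trades memory for recomputation, discarding intermediate activations in the forward pass and recomputing them during backward. A minimal sketch with plain `torch.utils.checkpoint` (independent of the codegen added in this commit; the module below is illustrative only):

```python
import torch
from torch.utils.checkpoint import checkpoint


class TinyBlock(torch.nn.Module):
    """Illustrative module, not part of this commit."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        # The checkpointed region does not keep its intermediate
        # activations; they are recomputed during the backward pass.
        return checkpoint(lambda t: self.linear2(self.linear1(t).relu()), x)


block = TinyBlock()
out = block(torch.rand(4, 4, requires_grad=True))
out.sum().backward()  # triggers recomputation of the checkpointed region
```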
Changes: 3 changed files, +441 −198

- colossalai/fx/codegen/__init__.py (+1, −3)
- colossalai/fx/codegen/activation_checkpoint_codegen.py (+403, −193)
- tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py (+37, −2)
colossalai/fx/codegen/__init__.py (+1, −3)

```diff
-from .activation_checkpoint_codegen import ActivationCheckpointCodeGen
-
-__all__ = ['ActivationCheckpointCodeGen']
+from .activation_checkpoint_codegen import *
\ No newline at end of file
```
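Switching `__init__.py` to a wildcard import lets the submodule decide what it exports per PyTorch version: `ActivationCheckpointCodeGen` on torch >= 1.12, and the fallback `python_code_with_activation_checkpoint` (imported by the test below) on older releases. A sketch of that guard, assuming a version check inside `activation_checkpoint_codegen.py`; the file's actual logic is collapsed in this capture and may differ:

```python
# Hypothetical version gate inside activation_checkpoint_codegen.py;
# the real file is not shown here.
import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    __all__ = ['ActivationCheckpointCodeGen']
else:
    __all__ = ['python_code_with_activation_checkpoint']
```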
colossalai/fx/codegen/activation_checkpoint_codegen.py (+403, −193)

(This diff is collapsed in the page capture and is not reproduced here.)
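The collapsed file contains the code generation itself. The test below asserts that the generated source contains `checkpoint_0` and `checkpoint_1`, which suggests each annotated region is extracted into a helper and invoked through the checkpoint API. A hedged sketch of what the emitted code might look like; the actual generated source is not shown in this capture and may differ:

```python
# Hypothetical shape of the generated code, inferred only from the
# test's 'checkpoint_0' / 'checkpoint_1' assertions below.
import torch


def checkpoint_0(self, x):
    # nodes annotated with activation_checkpoint region 0
    mlp1_linear1 = self.mlp1.linear1(x)
    return mlp1_linear1


def forward(self, x):
    # run region 0 under activation checkpointing
    x = torch.utils.checkpoint.checkpoint(checkpoint_0, self, x)
    return x
```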
tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py (+37, −2)

```diff
@@ -6,8 +6,11 @@ from colossalai.fx import ColoTracer
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
+    with_codegen = True
 except:
-    pass
+    # fall back to older pytorch version
+    from colossalai.fx.codegen import python_code_with_activation_checkpoint
+    with_codegen = False
 
 
 class MLP(torch.nn.Module):
@@ -35,7 +38,7 @@ class MyModule(torch.nn.Module):
         return y1 + y2 + y3 + y4
 
 
-@pytest.mark.skip("torch 1.12 is required")
+@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
 def test_act_ckpt_codegen():
     # build model and run forward
     model = MyModule()
@@ -65,5 +68,37 @@ def test_act_ckpt_codegen():
     assert torch.equal(non_fx_out, fx_out)
 
 
+@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
+def test_act_ckpt_python_code_torch11():
+    # build model and run forward
+    model = MyModule()
+    data = torch.rand(4, 4)
+    non_fx_out = model(data)
+
+    # trace the module and replace codegen
+    tracer = ColoTracer(trace_act_ckpt=True)
+    graph = tracer.trace(model)
+
+    # replace a bound method of an object
+    graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
+
+    # check ops are annotated with ckpt
+    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
+    for node in graph.nodes:
+        if node.name in ckpt_nodes:
+            assert hasattr(node, 'activation_checkpoint')
+
+    # assert checkpoint function will be generated
+    code = graph.python_code('self').src
+    assert 'checkpoint_0' in code and 'checkpoint_1' in code
+
+    # recompile and verify the outputs are consistent
+    gm = GraphModule(model, graph)
+    gm.recompile()
+    fx_out = gm(data)
+    assert torch.equal(non_fx_out, fx_out)
+
+
 if __name__ == '__main__':
     test_act_ckpt_codegen()
+    test_act_ckpt_python_code_torch11()
```
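One detail worth noting in the test: `python_code_with_activation_checkpoint.__get__(graph)` uses Python's descriptor protocol to bind a free function to one existing instance, patching `_python_code` on that `Graph` alone rather than on the class. A standalone illustration of the same trick (class and function names here are made up):

```python
class Greeter:
    def __init__(self, name):
        self.name = name


def shout(self):
    # A free function; `self` arrives once the function is bound.
    return self.name.upper()


g = Greeter('frank')
# Functions are descriptors: __get__(instance) returns a bound method
# for this instance only, leaving the class and other objects untouched.
g.shout = shout.__get__(g)
assert g.shout() == 'FRANK'
```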