Commit 4acc58ee in OpenDAS/ColossalAI (unverified), authored Aug 27, 2022 by Boyuan Yao, committed via GitHub Aug 27, 2022

[fx] Fix activation codegen dealing with checkpointing first op (#1510)
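This fix targets the case where a checkpointed region sits at the very start of the traced model, so every input to the region is a graph placeholder: a leaf tensor such as the raw data batch, which normally has requires_grad=False. PyTorch's reentrant checkpoint only reconnects the autograd graph through the tensors passed to it explicitly, so with all-leaf inputs it would silently detach the region's parameters from the loss; the codegen now falls back to use_reentrant=False for such regions. A minimal sketch of the underlying PyTorch behavior (using plain torch.utils.checkpoint rather than ColossalAI's wrapper, and assuming a PyTorch version that accepts the use_reentrant flag):

    import torch
    from torch.utils.checkpoint import checkpoint

    linear = torch.nn.Linear(4, 4)
    data = torch.rand(4, 4)  # leaf input, requires_grad=False

    # Reentrant checkpointing: no explicit input requires grad, so the
    # output has no grad_fn and backward would never reach linear's weights.
    out = checkpoint(linear, data, use_reentrant=True)
    print(out.requires_grad)  # False (PyTorch also warns that gradients will be None)

    # Non-reentrant checkpointing keeps the link to the parameters.
    out = checkpoint(linear, data, use_reentrant=False)
    print(out.requires_grad)  # True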
parent ac3a453a

Showing 2 changed files with 36 additions and 20 deletions:

    colossalai/fx/codegen/activation_checkpoint_codegen.py             +7  -0
    tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py   +29 -20
colossalai/fx/codegen/activation_checkpoint_codegen.py (view file @ 4acc58ee)

@@ -165,9 +165,12 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
         # we need to check if the checkpoint need use_reentrant=False
         use_reentrant = True
+        non_leaf_input = 0
         for var in input_vars[label]:
             input_node = [item for item in node_list if item.name == var]
             input_node = input_node[0]
+            if input_node.op != "placeholder":
+                non_leaf_input = 1
             for user in input_node.users:
                 if hasattr(user, "activation_checkpoint"):
                     if user.activation_checkpoint == label:

@@ -179,6 +182,10 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
                         if "inplace" in user.kwargs:
                             use_reentrant = not user.kwargs["inplace"]

+        # if all the inputs are leaf nodes, we need to set use_reentrant = False
+        if not non_leaf_input:
+            use_reentrant = False
+
         # generate checkpoint function call in a new line
         usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
         usage += '\n'
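In fx terms, the patch records whether any input of a checkpoint region is produced by a real op rather than by a placeholder node (a model input). The same test, sketched standalone against vanilla torch.fx (the function name and region_inputs parameter are illustrative, not part of this patch):

    import torch.fx

    def region_has_non_leaf_input(graph: torch.fx.Graph, region_inputs) -> bool:
        # Mirror the lookup over node_list above: resolve names to nodes.
        nodes_by_name = {node.name: node for node in graph.nodes}
        # Placeholders are the traced model's own inputs, i.e. leaf tensors.
        return any(nodes_by_name[name].op != "placeholder" for name in region_inputs)

If this returns False for a region, the fixed codegen emits its checkpoint call with use_reentrant=False.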
tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py (view file @ 4acc58ee)

@@ -49,16 +49,20 @@ class MyModule(torch.nn.Module):
         self.relu = relu()
         self.linear2 = torch.nn.Linear(4, 4)

-    def forward(self, x):
+    def ckpt2(self, x):
+        return F.relu(x, inplace=True)
+
+    def ckpt3(self, x, y):
+        return self.linear2(x) + self.linear2(y)
+
+    def forward(self, x, y):
         y1, y2 = checkpoint(self.mlp1, x)
         y3 = checkpoint(self.relu, x)
-
-        def ckpt2(x):
-            return F.relu(x, inplace=True)
-
-        y4 = checkpoint(ckpt2, x)
-        y4 = self.linear2(y4)
-        return y1 + y2 + y3 + y4
+        y4 = checkpoint(self.ckpt2, y)
+        y5 = checkpoint(self.ckpt3, y, y4)
+        y6 = self.linear2(y4)
+        return y1 + y2 + y3 + y4 + y5 + y6


 def _run_act_ckpt_codegen(rank):
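For reference, the four checkpoint regions the updated module produces and the use_reentrant flag the fixed codegen is expected to pick for each, read off the assertions further down (region numbering follows execution order):

    checkpoint_0  wraps self.mlp1(x)       use_reentrant=False  (x is a leaf placeholder)
    checkpoint_1  wraps self.relu(x)       use_reentrant=False  (x is a leaf placeholder)
    checkpoint_2  wraps self.ckpt2(y)      use_reentrant=False  (y is a leaf; the inplace relu inside also forces False)
    checkpoint_3  wraps self.ckpt3(y, y4)  use_reentrant=True   (y4, named relu in the generated code, is a non-leaf input)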
@@ -67,13 +71,15 @@ def _run_act_ckpt_codegen(rank):
     # build model and run forward
     model = MyModule()
-    data = torch.rand(4, 4)
+    data1 = torch.rand(4, 4)
+    data2 = torch.rand(4, 4)

     # copy model to cuda
     model = model.to(device="cuda")
-    data = data.to(device="cuda")
+    data1 = data1.to(device="cuda")
+    data2 = data2.to(device="cuda")

-    non_fx_out = model(data)
+    non_fx_out = model(data1, data2)

     # trace the module and replace codegen
     tracer = ColoTracer(trace_act_ckpt=True)
@@ -99,12 +105,13 @@ def _run_act_ckpt_codegen(rank):
     # assert checkpoint function will be generated and
     # the offload option is correct
     code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \
     'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
-    'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code
+    'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \
+    'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code

     # recompile and verify the outputs are consistent
-    fx_out = gm(data)
+    fx_out = gm(data1, data2)
     assert torch.equal(non_fx_out, fx_out)

     gpc.destroy()
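graph.python_code('self').src is standard torch.fx API: it returns the Python source generated for the graph, which is the string the assertions above match against. For example, to eyeball the equivalent source on any traced module:

    import torch
    import torch.fx

    gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
    print(gm.graph.python_code('self').src)  # same text as gm.code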
@@ -121,13 +128,14 @@ def _run_act_ckpt_python_code_torch11(rank):
     # build model and run forward
     model = MyModule()
-    data = torch.rand(4, 4)
+    data1 = torch.rand(4, 4)
+    data2 = torch.rand(4, 4)

     # copy model to cuda
     model = model.to(device="cuda")
-    data = data.to(device="cuda")
+    data1 = data1.to(device="cuda")
+    data2 = data2.to(device="cuda")

-    non_fx_out = model(data)
+    non_fx_out = model(data1, data2)

     # trace the module and replace codegen
     tracer = ColoTracer(trace_act_ckpt=True)
@@ -152,12 +160,13 @@ def _run_act_ckpt_python_code_torch11(rank):
     # assert checkpoint function will be generated and
     # the offload option is correct
     code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=True)' in code and \
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \
     'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \
-    'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, x, use_reentrant=False)' in code
+    'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \
+    'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code

     # recompile and verify the outputs are consistent
-    fx_out = gm(data)
+    fx_out = gm(data1, data2)
     assert torch.equal(non_fx_out, fx_out)

     gpc.destroy()
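The torch 1.11 variant above repeats the same assertions for the python-code-based codegen path. To exercise the updated test locally, a plain pytest run on the file should be enough, assuming a CUDA-capable machine (the test moves both model and data to cuda):

    pytest tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py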