OpenDAS / ColossalAI · commit cd063ac3

Unverified commit cd063ac3, authored Jul 25, 2022 by Frank Lee, committed via GitHub on Jul 25, 2022

[fx] added activation checkpoint codegen support for torch < 1.12 (#1359)
parent 44178041

Showing 3 changed files with 441 additions and 198 deletions (+441 -198)
colossalai/fx/codegen/__init__.py (+1 -3)
colossalai/fx/codegen/activation_checkpoint_codegen.py (+403 -193)
tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py (+37 -2)
colossalai/fx/codegen/__init__.py

-from .activation_checkpoint_codegen import ActivationCheckpointCodeGen
-
-__all__ = ['ActivationCheckpointCodeGen']
\ No newline at end of file
+from .activation_checkpoint_codegen import *
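With the wildcard re-export, the symbol this package exposes depends on the installed torch version (see the try/except gate in activation_checkpoint_codegen.py below). A minimal sketch of how a caller might pick the right entry point; enable_ckpt_codegen and HAS_CODEGEN are hypothetical names introduced here, and Graph.set_codegen is the torch >= 1.12 API:

from torch.fx import Graph

try:
    # torch >= 1.12: a CodeGen subclass can be attached to the graph
    from colossalai.fx.codegen import ActivationCheckpointCodeGen
    HAS_CODEGEN = True
except ImportError:
    # torch < 1.12: fall back to the standalone codegen function
    from colossalai.fx.codegen import python_code_with_activation_checkpoint
    HAS_CODEGEN = False

def enable_ckpt_codegen(graph: Graph) -> None:
    # hypothetical helper, not part of this commit
    if HAS_CODEGEN:
        graph.set_codegen(ActivationCheckpointCodeGen())
    else:
        # bind the free function as a method, as the test file below does
        graph._python_code = python_code_with_activation_checkpoint.__get__(graph)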
colossalai/fx/codegen/activation_checkpoint_codegen.py
 import torch
 from typing import List, Callable, Any, Tuple, Dict
-from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
-from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map
-
-__all__ = ['ActivationCheckpointCodeGen']
+
+try:
+    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
+    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods
+    codegen_available = True
+except:
+    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args
+    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
+    codegen_available = False
+
+if codegen_available:
+    __all__ = ['ActivationCheckpointCodeGen']
+else:
+    __all__ = ['python_code_with_activation_checkpoint']
-class ActivationCheckpointCodeGen(CodeGen):
-
-    def find_input_and_output_nodes(self, nodes: List[Node]):
+def _find_input_and_output_nodes(nodes: List[Node]):
     """
     Find the input and output node names which are not found in the given list of nodes.
     """
     ...
@@ -33,7 +41,8 @@ class ActivationCheckpointCodeGen(CodeGen):
     return input_nodes, output_nodes

-    def find_ckpt_regions(self, nodes: List[Node]):
+def _find_ckpt_regions(nodes: List[Node]):
     """
     Find the checkpoint regions given a list of consecutive nodes. The output is a list of
     tuples, each of the form (start_index, end_index).
     """
     ...
@@ -75,19 +84,22 @@ class ActivationCheckpointCodeGen(CodeGen):
         pass
     return ckpt_regions
-    def gen_ckpt_fn_def(self, label, free_vars: List[str]) -> str:
+def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
     """
     Generate the checkpoint function definition
     """
     return f"def checkpoint_{label}({', '.join(free_vars)}):"
-    def gen_ckpt_output(self, output_vars: List[str]) -> str:
+def _gen_ckpt_output(output_vars: List[str]) -> str:
     """
     Generate the return statement for checkpoint region
     """
     return f"return {', '.join(output_vars)}"
-    def gen_ckpt_usage(self, label, input_vars, output_vars):
+def _gen_ckpt_usage(label, input_vars, output_vars):
     """
     Generate the checkpoint function call code text
     """
     ...
@@ -95,6 +107,65 @@ class ActivationCheckpointCodeGen(CodeGen):
     inputs = ', '.join(input_vars)
     return f'{outputs} = torch.utils.checkpoint.checkpoint(checkpoint_{label}, {inputs})'
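Taken together, the three helpers above turn one checkpoint region into a nested function definition, a return statement, and a torch.utils.checkpoint.checkpoint call. A quick illustration of the strings they produce for a hypothetical region 0 with inputs x, y and a single output out (all names invented for the example):

_gen_ckpt_fn_def(0, ['x', 'y'])
# -> "def checkpoint_0(x, y):"

_gen_ckpt_output(['out'])
# -> "return out"

_gen_ckpt_usage(0, ['x', 'y'], ['out'])
# -> "out = torch.utils.checkpoint.checkpoint(checkpoint_0, x, y)"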
+def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
+    # find the activation checkpoint regions
+    ckpt_regions = _find_ckpt_regions(nodes)
+    start_idx = [item[0] for item in ckpt_regions]
+    end_idx = [item[1] for item in ckpt_regions]
+    input_vars = []
+    output_vars = []
+    within_ckpt_region = False
+
+    node_list = list(nodes)
+
+    # find the input and output var names for each region
+    for idx, (start, end) in enumerate(ckpt_regions):
+        ckpt_node_list = node_list[start:end + 1]
+        inputs, outputs = _find_input_and_output_nodes(ckpt_node_list)
+        input_vars.append(inputs)
+        output_vars.append(outputs)
+
+    # append code text to body
+    for idx, node in enumerate(node_list):
+        # if this is the first node of the ckpt region
+        # append the ckpt function definition
+        if idx in start_idx:
+            label = start_idx.index(idx)
+            ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
+            body.append(f'{ckpt_fn_def}\n')
+            within_ckpt_region = True
+
+        # NOTE: emit_node does not emit a string with newline. It depends
+        # on delete_unused_values to append one
+        emit_node_func(node)
+
+        # add indentation to the emitted node
+        if within_ckpt_region:
+            body[-1] = '    ' + body[-1]
+
+        # delete unused values
+        delete_unused_value_func(node)
+
+        if idx in end_idx:
+            # if this is the last node of the ckpt region
+            # generate return statement
+            label = end_idx.index(idx)
+            return_statement = _gen_ckpt_output(output_vars[label])
+            return_statement = f'    {return_statement}\n'
+            body.append(return_statement)
+
+            # generate checkpoint function call in a new line
+            usage = _gen_ckpt_usage(label, input_vars[label], output_vars[label])
+            usage += '\n'
+            body.append(usage)
+            within_ckpt_region = False
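For intuition, once emit_code_with_activation_checkpoint assembles these pieces, the generated forward for a module whose two linear layers form one checkpoint region would look roughly like the sketch below (module and variable names are invented; the ';  linear1 = None' suffix comes from delete_unused_value_func):

def forward(self, x):
    def checkpoint_0(x):
        linear1 = self.linear1(x)
        linear2 = self.linear2(linear1);  linear1 = None
        return linear2
    linear2 = torch.utils.checkpoint.checkpoint(checkpoint_0, x)
    return linear2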
+if codegen_available:
+
+    class ActivationCheckpointCodeGen(CodeGen):
+
+        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
+            free_vars: List[str] = []
+            body: List[str] = []
     ...
@@ -223,7 +294,8 @@ class ActivationCheckpointCodeGen(CodeGen):
                 return
             elif node.op == 'call_method':
                 assert isinstance(node.target, str)
-                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
+                body.append(
+                    f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
                     f'({_format_args(node.args[1:], node.kwargs)})')
                 return
             elif node.op == 'call_function':
     ...
@@ -275,70 +347,212 @@ class ActivationCheckpointCodeGen(CodeGen):
                     return
                 raise NotImplementedError(f'node: {node.op} {node.target}')

-        #########################################
-        # Modified for activation checkpointing #
-        #########################################
-        # find the activation checkpoint regions
-        ckpt_regions = self.find_ckpt_regions(nodes)
-        start_idx = [item[0] for item in ckpt_regions]
-        end_idx = [item[1] for item in ckpt_regions]
-        input_vars = []
-        output_vars = []
-        within_ckpt_region = False
-
-        node_list = list(nodes)
-
-        # find the input and output var names for each region
-        for idx, (start, end) in enumerate(ckpt_regions):
-            ckpt_node_list = node_list[start:end + 1]
-            inputs, outputs = self.find_input_and_output_nodes(ckpt_node_list)
-            input_vars.append(inputs)
-            output_vars.append(outputs)
-
-        # append code text to body
-        for idx, node in enumerate(node_list):
-            # if this is the first node of the ckpt region
-            # append the ckpt function definition
-            if idx in start_idx:
-                label = start_idx.index(idx)
-                ckpt_fn_def = self.gen_ckpt_fn_def(label, input_vars[label])
-                body.append(f'{ckpt_fn_def}\n')
-                within_ckpt_region = True
-
-            # NOTE: emit_node does not emit a string with newline. It depends
-            # on delete_unused_values to append one
-            emit_node(node)
-
-            # add indentation to the emitted node
-            if within_ckpt_region:
-                body[-1] = '    ' + body[-1]
-
-            # delete unused values
-            delete_unused_values(node)
-
-            if idx in end_idx:
-                # if this is the last node of the ckpt region
-                # generate return statement
-                label = end_idx.index(idx)
-                return_statement = self.gen_ckpt_output(output_vars[label])
-                return_statement = f'    {return_statement}\n'
-                body.append(return_statement)
-
-                # generate checkpoint function call in a new line
-                usage = self.gen_ckpt_usage(label, input_vars[label], output_vars[label])
-                usage += '\n'
-                body.append(usage)
-                within_ckpt_region = False
+            # Modified for activation checkpointing
+            emit_code_with_activation_checkpoint(body, nodes, emit_node, delete_unused_values)
+
+            if len(body) == 0:
+                # If the Graph has no non-placeholder nodes, no lines for the body
+                # have been emitted. To continue to have valid Python code, emit a
+                # single pass statement
+                body.append('pass\n')
+
+            if len(wrapped_fns) > 0:
+                wrap_name = add_global('wrap', torch.fx.wrap)
+                wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+            else:
+                wrap_stmts = ''
+
+            if self._body_transformer:
+                body = self._body_transformer(body)
+
+            for name, value in self.additional_globals():
+                add_global(name, value)
+
+            prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
+
+            code = ''.join(body)
+            code = '\n'.join('    ' + line for line in code.split('\n'))
+            fn_code = f"""
+{wrap_stmts}
+{prologue}
+{code}"""
+            return PythonCode(fn_code, globals_)
+else:
+
+    def python_code_with_activation_checkpoint(self, root_module: str, namespace: _Namespace) -> PythonCode:
+        """
+        This method is copied from the _python_code of torch.fx.graph.Graph. Modifications are made so that it
+        can generate code for activation checkpoint.
+        """
+        free_vars: List[str] = []
+        body: List[str] = []
+        globals_: Dict[str, Any] = {}
+        wrapped_fns: Dict[str, None] = {}
-        #######################################################
-        # Code Change For Activation Checkpointing Stops Here #
-        #######################################################
+
+        # Wrap string in list to pass by reference
+        maybe_return_annotation: List[str] = ['']
+        def add_global(name_hint: str, obj: Any):
+            """Add an obj to be tracked as a global.
+
+            We call this for names that reference objects external to the
+            Graph, like functions or types.
+
+            Returns: the global name that should be used to reference 'obj' in generated source.
+            """
+            if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device
+                # HACK: workaround for how torch custom ops are registered. We
+                # can't import them like normal modules so they must retain their
+                # fully qualified name.
+                return _get_qualified_name(obj)
+
+            # normalize the name hint to get a proper identifier
+            global_name = namespace.create_name(name_hint, obj)
+
+            if global_name in globals_:
+                assert globals_[global_name] is obj
+                return global_name
+            globals_[global_name] = obj
+            return global_name
+
+        # Pre-fill the globals table with registered builtins.
+        for name, (_, obj) in _custom_builtins.items():
+            add_global(name, obj)
+        def type_repr(o: Any):
+            if o == ():
+                # Empty tuple is used for empty tuple type annotation Tuple[()]
+                return '()'
+
+            typename = _type_repr(o)
+
+            # This is a generic type, e.g. typing.List[torch.Tensor]
+            if hasattr(o, '__origin__'):
+                origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
+                origin_typename = add_global(_type_repr(origin_type), origin_type)
+
+                # Assign global names for each of the inner type variables.
+                args = [type_repr(arg) for arg in o.__args__]
+
+                return f'{origin_typename}[{",".join(args)}]'
+
+            # Common case: this is a regular module name like 'foo.bar.baz'
+            return add_global(typename, o)
+        # Run through reverse nodes and record the first instance of a use
+        # of a given node. This represents the *last* use of the node in the
+        # execution order of the program, which we will use to free unused
+        # values
+        node_to_last_use: Dict[Node, Node] = {}
+        user_to_last_uses: Dict[Node, List[Node]] = {}
+
+        def register_last_uses(n: Node, user: Node):
+            if n not in node_to_last_use:
+                node_to_last_use[n] = user
+                user_to_last_uses.setdefault(user, []).append(n)
+
+        for node in reversed(self.nodes):
+            map_arg(node.args, lambda n: register_last_uses(n, node))
+            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+        def delete_unused_values(user: Node):
+            """
+            Delete values after their last use. This ensures that values that are
+            not used in the remainder of the code are freed and the memory usage
+            of the code is optimal.
+            """
+            if user.op == 'placeholder':
+                return
+            if user.op == 'output':
+                body.append('\n')
+                return
+            nodes_to_delete = user_to_last_uses.get(user, [])
+            if len(nodes_to_delete):
+                to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
+                body.append(f';  {to_delete_str}\n')
+            else:
+                body.append('\n')
+        def emit_node(node: Node):
+            maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
+            if node.op == 'placeholder':
+                assert isinstance(node.target, str)
+                maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
+                free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
+                raw_name = node.target.replace('*', '')
+                if raw_name != repr(node):
+                    body.append(f'{repr(node)} = {raw_name}\n')
+                return
+            elif node.op == 'call_method':
+                assert isinstance(node.target, str)
+                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
+                            f'({_format_args(node.args[1:], node.kwargs)})')
+                return
+            elif node.op == 'call_function':
+                assert callable(node.target)
+                # pretty print operators
+                if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+                    assert isinstance(node.args, tuple)
+                    body.append(f'{repr(node)}{maybe_type_annotation} = '
+                                f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+                    return
+                qualified_name = _get_qualified_name(node.target)
+                global_name = add_global(qualified_name, node.target)
+                # special case for getattr: node.args could be 2-argument or 3-argument
+                # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
+                if global_name == 'getattr' and \
+                        isinstance(node.args, tuple) and \
+                        isinstance(node.args[1], str) and \
+                        node.args[1].isidentifier() and \
+                        len(node.args) == 2:
+                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+                    return
+                body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
+                if node.meta.get('is_wrapped', False):
+                    wrapped_fns.setdefault(global_name)
+                return
+            elif node.op == 'call_module':
+                assert isinstance(node.target, str)
+                body.append(f'{repr(node)}{maybe_type_annotation} = '
+                            f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+                return
+            elif node.op == 'get_attr':
+                assert isinstance(node.target, str)
+                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+                return
+            elif node.op == 'output':
+                if node.type is not None:
+                    maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
+                if self._pytree_info is None:
+                    body.append(f'return {repr(node.args[0])}')
+                else:
+                    body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)')
+                return
+            raise NotImplementedError(f'node: {node.op} {node.target}')
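As a reading aid, emit_node maps each fx opcode to one generated line, roughly as follows (all node and attribute names invented):

# placeholder   -> adds 'x' to free_vars (a body line only for '*args'-style names)
# call_method   -> "relu_1 = x.relu()"
# call_function -> "add_1 = torch.add(x, y)", or "add_1 = x + y" via magic_methods
# call_module   -> "linear1 = self.linear1(x)"
# get_attr      -> "weight = self.weight"
# output        -> "return add_1"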
+        # Modified for activation checkpointing
+        emit_code_with_activation_checkpoint(body, self.nodes, emit_node, delete_unused_values)
+
+        if len(body) == 0:
+            # If the Graph has no non-placeholder nodes, no lines for the body
+            # have been emitted. To continue to have valid Python code, emit a
+            # single pass statement
+            body.append('pass\n')
+
+        if self._pytree_info is not None:
+            orig_args = self._pytree_info.orig_args
+            has_orig_self = (orig_args[0] == 'self')
+            if has_orig_self:
+                free_vars.insert(0, 'self')
+            if len(free_vars) > 0:    # pytree has placeholders in it
+                body.insert(
+                    0, f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n")
+        else:
+            orig_args = free_vars
+
+        if len(wrapped_fns) > 0:
+            wrap_name = add_global('wrap', torch.fx.wrap)
     ...
@@ -346,19 +560,15 @@ class ActivationCheckpointCodeGen(CodeGen):
         else:
             wrap_stmts = ''

-        if self._body_transformer:
-            body = self._body_transformer(body)
-
-        for name, value in self.additional_globals():
-            add_global(name, value)
-
-        prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
+        # If the original function didn't have self as its first argument, we
+        # would have added it.
+        if len(orig_args) == 0 or orig_args[0] != 'self':
+            orig_args.insert(0, 'self')

         code = ''.join(body)
         code = '\n'.join('    ' + line for line in code.split('\n'))

         fn_code = f"""
 {wrap_stmts}
-{prologue}
+def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
 {code}"""
         return PythonCode(fn_code, globals_)
tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py

     ...
@@ -6,8 +6,11 @@ from colossalai.fx import ColoTracer
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
     with_codegen = True
 except:
-    pass
+    # fall back to older pytorch version
+    from colossalai.fx.codegen import python_code_with_activation_checkpoint
+    with_codegen = False


 class MLP(torch.nn.Module):
     ...
@@ -35,7 +38,7 @@ class MyModule(torch.nn.Module):
         return y1 + y2 + y3 + y4


-@pytest.mark.skip("torch 1.12 is required")
+@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
 def test_act_ckpt_codegen():
     # build model and run forward
     model = MyModule()
     ...
@@ -65,5 +68,37 @@ def test_act_ckpt_codegen():
     assert torch.equal(non_fx_out, fx_out)
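The body of test_act_ckpt_codegen is collapsed in this view; on torch >= 1.12 the class-based path is typically wired up roughly as in this sketch (a guess at the elided body, not a copy of it; imports follow the test file's):

model = MyModule()
data = torch.rand(4, 4)
non_fx_out = model(data)

tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(model)
graph.set_codegen(ActivationCheckpointCodeGen())    # torch >= 1.12 API

gm = GraphModule(model, graph)
gm.recompile()
assert torch.equal(non_fx_out, gm(data))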
+@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
+def test_act_ckpt_python_code_torch11():
+    # build model and run forward
+    model = MyModule()
+    data = torch.rand(4, 4)
+    non_fx_out = model(data)
+
+    # trace the module and replace codegen
+    tracer = ColoTracer(trace_act_ckpt=True)
+    graph = tracer.trace(model)
+
+    # replace a bound method of an object
+    graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
+
+    # check ops are annotated with ckpt
+    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
+    for node in graph.nodes:
+        if node.name in ckpt_nodes:
+            assert hasattr(node, 'activation_checkpoint')
+
+    # assert checkpoint function will be generated
+    code = graph.python_code('self').src
+    assert 'checkpoint_0' in code and 'checkpoint_1' in code
+
+    # recompile and verify the outputs are consistent
+    gm = GraphModule(model, graph)
+    gm.recompile()
+    fx_out = gm(data)
+    assert torch.equal(non_fx_out, fx_out)


 if __name__ == '__main__':
     test_act_ckpt_codegen()
+    test_act_ckpt_python_code_torch11()