ColossalAI · Commit cd063ac3 (Unverified)

Authored by Frank Lee on Jul 25, 2022; committed by GitHub on Jul 25, 2022.
[fx] added activation checkpoint codegen support for torch < 1.12 (#1359)
Parent: 44178041
Showing 3 changed files with 441 additions and 198 deletions:

- colossalai/fx/codegen/__init__.py (+1, -3)
- colossalai/fx/codegen/activation_checkpoint_codegen.py (+403, -193)
- tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py (+37, -2)
colossalai/fx/codegen/__init__.py

-from .activation_checkpoint_codegen import ActivationCheckpointCodeGen
-
-__all__ = ['ActivationCheckpointCodeGen']
+from .activation_checkpoint_codegen import *
\ No newline at end of file
colossalai/fx/codegen/activation_checkpoint_codegen.py
 import torch
 from typing import List, Callable, Any, Tuple, Dict
-from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
-from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map
-
-__all__ = ['ActivationCheckpointCodeGen']
+
+try:
+    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
+    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods
+    codegen_available = True
+except:
+    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args
+    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
+    codegen_available = False
+
+if codegen_available:
+    __all__ = ['ActivationCheckpointCodeGen']
+else:
+    __all__ = ['python_code_with_activation_checkpoint']

Everything below the import gate is newly added; the region-finding and code-generation helpers now live at module level:
def _find_input_and_output_nodes(nodes: List[Node]):
    """
    Find the input and output node names which are not found in the given list of nodes.
    """
    input_nodes = []
    output_nodes = []

    # if a node has an input node which is not in the node list
    # we treat that input node as the input of the checkpoint function
    for node in nodes:
        for input_node in node._input_nodes.keys():
            node_repr = repr(input_node)
            if input_node not in nodes and node_repr not in input_nodes:
                input_nodes.append(node_repr)

    # if a node has a user node which is not in the node list
    # we treat that user node as the node receiving the current node output
    for node in nodes:
        for output_node in node.users.keys():
            node_repr = repr(node)
            if output_node not in nodes and node_repr not in output_nodes:
                output_nodes.append(node_repr)

    return input_nodes, output_nodes
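
As a concrete probe of this contract (a sketch, not part of the commit; the Tiny module and its node names are made up), tracing a two-layer module with stock torch.fx and slicing out the middle of the graph returns the placeholder feeding the slice and the node consumed downstream:

    import torch
    from colossalai.fx.codegen.activation_checkpoint_codegen import _find_input_and_output_nodes

    class Tiny(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(4, 4)
            self.linear2 = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    # graph nodes: [x (placeholder), linear1, linear2, output]
    nodes = list(torch.fx.symbolic_trace(Tiny()).graph.nodes)
    inputs, outputs = _find_input_and_output_nodes(nodes[1:3])
    print(inputs, outputs)    # ['x'] ['linear2']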
def _find_ckpt_regions(nodes: List[Node]):
    """
    Find the checkpoint regions given a list of consecutive nodes. The output will be a list
    of tuples, each tuple in the form of (start_index, end_index).
    """
    ckpt_nodes = []
    ckpt_regions = []
    start = -1
    end = -1
    current_region = None

    for idx, node in enumerate(nodes):
        if hasattr(node, 'activation_checkpoint'):
            act_ckpt_label = node.activation_checkpoint

            # this activation checkpoint label is not set yet
            # meaning this is the first node of the activation ckpt region
            if current_region is None:
                current_region = act_ckpt_label
                start = idx

            # if activation checkpoint has changed
            # we restart the tracking
            # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
            if act_ckpt_label != current_region:
                assert start != -1
                ckpt_regions.append((start, idx - 1))
                current_region = act_ckpt_label
                start = idx
                end = -1
        elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
            # used to check the case below
            # node ckpt states = [ckpt, ckpt, non-ckpt]
            end = idx - 1
            assert start != -1 and end != -1
            ckpt_regions.append((start, end))
            start = end = -1
            current_region = None
        else:
            pass
    return ckpt_regions
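
The region finder only reads the activation_checkpoint attribute, so stand-in objects are enough to see the (start_index, end_index) contract (a minimal sketch, not from the commit):

    from types import SimpleNamespace
    from colossalai.fx.codegen.activation_checkpoint_codegen import _find_ckpt_regions

    plain = SimpleNamespace                          # nodes without a ckpt annotation
    ckpt = lambda label: SimpleNamespace(activation_checkpoint=label)

    nodes = [plain(), ckpt(0), ckpt(0), ckpt(1), ckpt(1), plain()]
    print(_find_ckpt_regions(nodes))                 # [(1, 2), (3, 4)]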
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
    """
    Generate the checkpoint function definition
    """
    return f"def checkpoint_{label}({', '.join(free_vars)}):"


def _gen_ckpt_output(output_vars: List[str]) -> str:
    """
    Generate the return statement for checkpoint region
    """
    return f"return {', '.join(output_vars)}"


def _gen_ckpt_usage(label, input_vars, output_vars):
    """
    Generate the checkpoint function call code text
    """
    outputs = ', '.join(output_vars)
    inputs = ', '.join(input_vars)
    return f'{outputs} = torch.utils.checkpoint.checkpoint(checkpoint_{label}, {inputs})'
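
Feeding a region label and its boundary names through these helpers shows exactly what text they produce (a quick illustration, not part of the commit):

    from colossalai.fx.codegen.activation_checkpoint_codegen import (
        _gen_ckpt_fn_def, _gen_ckpt_output, _gen_ckpt_usage)

    print(_gen_ckpt_fn_def(0, ['x']))           # def checkpoint_0(x):
    print(_gen_ckpt_output(['relu']))           # return relu
    print(_gen_ckpt_usage(0, ['x'], ['relu']))
    # relu = torch.utils.checkpoint.checkpoint(checkpoint_0, x)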
def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
    # find the activation checkpoint regions
    ckpt_regions = _find_ckpt_regions(nodes)
    start_idx = [item[0] for item in ckpt_regions]
    end_idx = [item[1] for item in ckpt_regions]
    input_vars = []
    output_vars = []
    within_ckpt_region = False

    node_list = list(nodes)

    # find the input and output var names for each region
    for idx, (start, end) in enumerate(ckpt_regions):
        ckpt_node_list = node_list[start:end + 1]
        inputs, outputs = _find_input_and_output_nodes(ckpt_node_list)
        input_vars.append(inputs)
        output_vars.append(outputs)

    # append code text to body
    for idx, node in enumerate(node_list):
        # if this is the first node of the ckpt region
        # append the ckpt function definition
        if idx in start_idx:
            label = start_idx.index(idx)
            ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
            body.append(f'{ckpt_fn_def}\n')
            within_ckpt_region = True

        # NOTE: emit_node does not emit a string with newline. It depends
        # on delete_unused_values to append one
        emit_node_func(node)

        # add indentation to the emitted node
        if within_ckpt_region:
            body[-1] = '    ' + body[-1]

        # delete unused values
        delete_unused_value_func(node)

        if idx in end_idx:
            # if this is the last node of the ckpt region
            # generate return statement
            label = end_idx.index(idx)
            return_statement = _gen_ckpt_output(output_vars[label])
            return_statement = f'    {return_statement}\n'
            body.append(return_statement)

            # generate checkpoint function call in a new line
            usage = _gen_ckpt_usage(label, input_vars[label], output_vars[label])
            usage += '\n'
            body.append(usage)
            within_ckpt_region = False
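
For a single region, the emitted text therefore reads roughly as follows (illustrative node names; whitespace handling and unused-value deletion elided):

    def checkpoint_0(x):
        linear1 = self.linear1(x)
        relu = torch.relu(linear1)
        return relu
    relu = torch.utils.checkpoint.checkpoint(checkpoint_0, x)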
The torch >= 1.12 path wraps these helpers in a CodeGen subclass. The previous version of the file carried the same logic as methods on the class (find_input_and_output_nodes, find_ckpt_regions, gen_ckpt_fn_def, gen_ckpt_output, gen_ckpt_usage, with their bodies identical to the module-level functions above) and inlined the emission loop in _gen_python_code; this commit removes those methods and delegates emission to emit_code_with_activation_checkpoint:

if codegen_available:

    class ActivationCheckpointCodeGen(CodeGen):

        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
            free_vars: List[str] = []
            body: List[str] = []
            globals_: Dict[str, Any] = {}
            wrapped_fns: Dict[str, None] = {}

            # Wrap string in list to pass by reference
            maybe_return_annotation: List[str] = ['']

            def add_global(name_hint: str, obj: Any):
                """Add an obj to be tracked as a global.

                We call this for names that reference objects external to the
                Graph, like functions or types.

                Returns: the global name that should be used to reference 'obj' in generated source.
                """
                if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device
                    # HACK: workaround for how torch custom ops are registered. We
                    # can't import them like normal modules so they must retain their
                    # fully qualified name.
                    return _get_qualified_name(obj)

                # normalize the name hint to get a proper identifier
                global_name = namespace.create_name(name_hint, obj)

                if global_name in globals_:
                    assert globals_[global_name] is obj
                    return global_name
                globals_[global_name] = obj
                return global_name

            # Pre-fill the globals table with registered builtins.
            for name, (_, obj) in _custom_builtins.items():
                add_global(name, obj)

            def type_repr(o: Any):
                if o == ():
                    # Empty tuple is used for empty tuple type annotation Tuple[()]
                    return '()'

                typename = _type_repr(o)

                if hasattr(o, '__origin__'):
                    # This is a generic type, e.g. typing.List[torch.Tensor]
                    origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
                    origin_typename = add_global(_type_repr(origin_type), origin_type)

                    if hasattr(o, '__args__'):
                        # Assign global names for each of the inner type variables.
                        args = [type_repr(arg) for arg in o.__args__]

                        if len(args) == 0:
                            # Bare type, such as `typing.Tuple` with no subscript
                            # This code-path used in Python < 3.9
                            return origin_typename

                        return f'{origin_typename}[{",".join(args)}]'
                    else:
                        # Bare type, such as `typing.Tuple` with no subscript
                        # This code-path used in Python 3.9+
                        return origin_typename

                # Common case: this is a regular module name like 'foo.bar.baz'
                return add_global(typename, o)

            def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:

                def _get_repr(arg):
                    # Handle NamedTuples (if it has `_fields`) via add_global.
                    if isinstance(arg, tuple) and hasattr(arg, '_fields'):
                        qualified_name = _get_qualified_name(type(arg))
                        global_name = add_global(qualified_name, type(arg))
                        return f"{global_name}{repr(tuple(arg))}"
                    return repr(arg)

                args_s = ', '.join(_get_repr(a) for a in args)
                kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
                if args_s and kwargs_s:
                    return f'{args_s}, {kwargs_s}'
                return args_s or kwargs_s

            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use: Dict[Node, Node] = {}
            user_to_last_uses: Dict[Node, List[Node]] = {}

            def register_last_uses(n: Node, user: Node):
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(nodes):
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))

            def delete_unused_values(user: Node):
                """
                Delete values after their last use. This ensures that values that are
                not used in the remainder of the code are freed and the memory usage
                of the code is optimal.
                """
                if user.op == 'placeholder':
                    return
                if user.op == 'output':
                    body.append('\n')
                    return
                nodes_to_delete = user_to_last_uses.get(user, [])
                if len(nodes_to_delete):
                    to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                    body.append(f';  {to_delete_str}\n')
                else:
                    body.append('\n')

            def emit_node(node: Node):
                maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'

                if node.op == 'placeholder':
                    assert isinstance(node.target, str)
                    maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
                    free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
                    raw_name = node.target.replace('*', '')
                    if raw_name != repr(node):
                        body.append(f'{repr(node)} = {raw_name}\n')
                    return
                elif node.op == 'call_method':
                    assert isinstance(node.target, str)
                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
                                f'({_format_args(node.args[1:], node.kwargs)})')
                    return
                elif node.op == 'call_function':
                    assert callable(node.target)

                    # pretty print operators
                    if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
                        assert isinstance(node.args, tuple)
                        body.append(f'{repr(node)}{maybe_type_annotation} = '
                                    f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
                        return

                    # pretty print inplace operators; required for jit.script to work properly
                    # not currently supported in normal FX graphs, but generated by torchdynamo
                    if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
                        body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))};  '
                                    f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
                        return

                    qualified_name = _get_qualified_name(node.target)
                    global_name = add_global(qualified_name, node.target)

                    # special case for getattr: node.args could be 2-argument or 3-argument
                    # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
                    if global_name == 'getattr' and \
                            isinstance(node.args, tuple) and \
                            isinstance(node.args[1], str) and \
                            node.args[1].isidentifier() and \
                            len(node.args) == 2:
                        body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
                        return
                    body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
                    if node.meta.get('is_wrapped', False):
                        wrapped_fns.setdefault(global_name)
                    return
                elif node.op == 'call_module':
                    assert isinstance(node.target, str)
                    body.append(f'{repr(node)}{maybe_type_annotation} = '
                                f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
                    return
                elif node.op == 'get_attr':
                    assert isinstance(node.target, str)
                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
                    return
                elif node.op == 'output':
                    if node.type is not None:
                        maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
                    body.append(self.generate_output(node.args[0]))
                    return
                raise NotImplementedError(f'node: {node.op} {node.target}')

            # Modified for activation checkpointing
            emit_code_with_activation_checkpoint(body, nodes, emit_node, delete_unused_values)

            if len(body) == 0:
                # If the Graph has no non-placeholder nodes, no lines for the body
                # have been emitted. To continue to have valid Python code, emit a
                # single pass statement
                body.append('pass\n')

            if len(wrapped_fns) > 0:
                wrap_name = add_global('wrap', torch.fx.wrap)
                wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
            else:
                wrap_stmts = ''

            if self._body_transformer:
                body = self._body_transformer(body)

            for name, value in self.additional_globals():
                add_global(name, value)

            prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])

            code = ''.join(body)
            code = '\n'.join('    ' + line for line in code.split('\n'))
            fn_code = f"""
{wrap_stmts}

{prologue}
{code}"""
            return PythonCode(fn_code, globals_)
else:

    def python_code_with_activation_checkpoint(self, root_module: str, namespace: _Namespace) -> PythonCode:
        """
        This method is copied from the _python_code of torch.fx.graph.Graph. Modifications are made so that it can generate
        code for activation checkpoint.
        """
        free_vars: List[str] = []
        body: List[str] = []
        globals_: Dict[str, Any] = {}
        ...

The diff view elides most of this fallback, which replaces the old class's _gen_python_code(self, nodes, ...) method and otherwise mirrors the torch >= 1.12 path above; the remaining hunks show where the two differ (old lines prefixed with -, new lines with +):
@@ -138,45 +428,19 @@ class ActivationCheckpointCodeGen(CodeGen):
             typename = _type_repr(o)

+            # This is a generic type, e.g. typing.List[torch.Tensor]
             if hasattr(o, '__origin__'):
-                # This is a generic type, e.g. typing.List[torch.Tensor]
                 origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
                 origin_typename = add_global(_type_repr(origin_type), origin_type)

-                if hasattr(o, '__args__'):
-                    # Assign global names for each of the inner type variables.
-                    args = [type_repr(arg) for arg in o.__args__]
-
-                    if len(args) == 0:
-                        # Bare type, such as `typing.Tuple` with no subscript
-                        # This code-path used in Python < 3.9
-                        return origin_typename
-
-                    return f'{origin_typename}[{",".join(args)}]'
-                else:
-                    # Bare type, such as `typing.Tuple` with no subscript
-                    # This code-path used in Python 3.9+
-                    return origin_typename
+                # Assign global names for each of the inner type variables.
+                args = [type_repr(arg) for arg in o.__args__]
+
+                return f'{origin_typename}[{",".join(args)}]'

             # Common case: this is a regular module name like 'foo.bar.baz'
             return add_global(typename, o)

-        def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-
-            def _get_repr(arg):
-                # Handle NamedTuples (if it has `_fields`) via add_global.
-                if isinstance(arg, tuple) and hasattr(arg, '_fields'):
-                    qualified_name = _get_qualified_name(type(arg))
-                    global_name = add_global(qualified_name, type(arg))
-                    return f"{global_name}{repr(tuple(arg))}"
-                return repr(arg)
-
-            args_s = ', '.join(_get_repr(a) for a in args)
-            kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
-            if args_s and kwargs_s:
-                return f'{args_s}, {kwargs_s}'
-            return args_s or kwargs_s

         # Run through reverse nodes and record the first instance of a use
         # of a given node. This represents the *last* use of the node in the
         # execution order of the program, which we will use to free unused
         # values
@@ -189,7 +453,7 @@ class ActivationCheckpointCodeGen(CodeGen):
                 node_to_last_use[n] = user
                 user_to_last_uses.setdefault(user, []).append(n)

-        for node in reversed(nodes):
+        for node in reversed(self.nodes):
             map_arg(node.args, lambda n: register_last_uses(n, node))
             map_arg(node.kwargs, lambda n: register_last_uses(n, node))
@@ -234,14 +498,6 @@ class ActivationCheckpointCodeGen(CodeGen):
                 body.append(f'{repr(node)}{maybe_type_annotation} = '
                             f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
                 return

-            # pretty print inplace operators; required for jit.script to work properly
-            # not currently supported in normal FX graphs, but generated by torchdynamo
-            if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
-                body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))};  '
-                            f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
-                return
-
             qualified_name = _get_qualified_name(node.target)
             global_name = add_global(qualified_name, node.target)

             # special case for getattr: node.args could be 2-argument or 3-argument
@@ -271,74 +527,32 @@ class ActivationCheckpointCodeGen(CodeGen):
             elif node.op == 'output':
                 if node.type is not None:
                     maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
-                body.append(self.generate_output(node.args[0]))
+                if self._pytree_info is None:
+                    body.append(f'return {repr(node.args[0])}')
+                else:
+                    body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)')
                 return
             raise NotImplementedError(f'node: {node.op} {node.target}')

+        #########################################
+        # Modified for activation checkpointing #
+        emit_code_with_activation_checkpoint(body, self.nodes, emit_node, delete_unused_values)
+        #########################################
-        # find the activation checkpoint regions
-        ckpt_regions = self.find_ckpt_regions(nodes)
-        start_idx = [item[0] for item in ckpt_regions]
-        end_idx = [item[1] for item in ckpt_regions]
-        input_vars = []
-        output_vars = []
-        within_ckpt_region = False
-        node_list = list(nodes)
-
-        # find the input and output var names for each region
-        for idx, (start, end) in enumerate(ckpt_regions):
-            ckpt_node_list = node_list[start:end + 1]
-            inputs, outputs = self.find_input_and_output_nodes(ckpt_node_list)
-            input_vars.append(inputs)
-            output_vars.append(outputs)
-
-        # append code text to body
-        for idx, node in enumerate(node_list):
-            # if this is the first node of the ckpt region
-            # append the ckpt function defition
-            if idx in start_idx:
-                label = start_idx.index(idx)
-                ckpt_fn_def = self.gen_ckpt_fn_def(label, input_vars[label])
-                body.append(f'{ckpt_fn_def}\n')
-                within_ckpt_region = True
-
-            # NOTE: emit_node does not emit a string with newline. It depends
-            # on delete_unused_values to append one
-            emit_node(node)
-
-            # add indentation to the emmited node
-            if within_ckpt_region:
-                body[-1] = '    ' + body[-1]
-
-            # delete unused values
-            delete_unused_values(node)
-
-            if idx in end_idx:
-                # if this is the last node of the ckpt region
-                # generate return statement
-                label = end_idx.index(idx)
-                return_statement = self.gen_ckpt_output(output_vars[label])
-                return_statement = f'    {return_statement}\n'
-                body.append(return_statement)
-
-                # generate checkpoint function call in a new line
-                usage = self.gen_ckpt_usage(label, input_vars[label], output_vars[label])
-                usage += '\n'
-                body.append(usage)
-                within_ckpt_region = False
+        #######################################################
+        # Code Change For Activation Checkpointing Stops Here #
+        #######################################################

         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body
             # have been emitted. To continue to have valid Python code, emit a
             # single pass statement
             body.append('pass\n')

+        if self._pytree_info is not None:
+            orig_args = self._pytree_info.orig_args
+            has_orig_self = (orig_args[0] == 'self')
+            if has_orig_self:
+                free_vars.insert(0, 'self')
+            if len(free_vars) > 0:
+                # pytree has placeholders in it
+                body.insert(0, f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n")
+        else:
+            orig_args = free_vars

         if len(wrapped_fns) > 0:
             wrap_name = add_global('wrap', torch.fx.wrap)
@@ -346,19 +560,15 @@ class ActivationCheckpointCodeGen(CodeGen):
         else:
             wrap_stmts = ''

-        if self._body_transformer:
-            body = self._body_transformer(body)
-
-        for name, value in self.additional_globals():
-            add_global(name, value)
-
-        prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
+        # If the original function didn't have self as its first argument, we
+        # would have added it.
+        if len(orig_args) == 0 or orig_args[0] != 'self':
+            orig_args.insert(0, 'self')

         code = ''.join(body)
         code = '\n'.join('    ' + line for line in code.split('\n'))
         fn_code = f"""
 {wrap_stmts}

-{prologue}
+def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
 {code}"""
         return PythonCode(fn_code, globals_)
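
On the torch >= 1.12 path the subclass plugs into an FX graph through the stock CodeGen hook; a usage sketch mirroring the test below (Graph.set_codegen is standard torch >= 1.12 API and is not shown in this diff):

    import torch
    from torch.fx import GraphModule
    from colossalai.fx import ColoTracer
    from colossalai.fx.codegen import ActivationCheckpointCodeGen

    model = MyModule()                            # the module defined in the test file
    tracer = ColoTracer(trace_act_ckpt=True)      # annotates nodes with activation_checkpoint
    graph = tracer.trace(model)
    graph.set_codegen(ActivationCheckpointCodeGen())

    gm = GraphModule(model, graph)
    gm.recompile()                                # forward() now wraps each region in checkpoint_<label>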
tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py

@@ -6,8 +6,11 @@ from colossalai.fx import ColoTracer
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
+    with_codegen = True
 except:
-    pass
+    # fall back to older pytorch version
+    from colossalai.fx.codegen import python_code_with_activation_checkpoint
+    with_codegen = False


 class MLP(torch.nn.Module):

@@ -35,7 +38,7 @@ class MyModule(torch.nn.Module):
         return y1 + y2 + y3 + y4


-@pytest.mark.skip("torch 1.12 is required")
+@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
 def test_act_ckpt_codegen():
     # build model and run forward
     model = MyModule()

@@ -65,5 +68,37 @@ def test_act_ckpt_codegen():
     assert torch.equal(non_fx_out, fx_out)


+@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
+def test_act_ckpt_python_code_torch11():
+    # build model and run forward
+    model = MyModule()
+    data = torch.rand(4, 4)
+    non_fx_out = model(data)
+
+    # trace the module and replace codegen
+    tracer = ColoTracer(trace_act_ckpt=True)
+    graph = tracer.trace(model)
+
+    # replace a bound method of an object
+    graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
+
+    # check ops are annotated with ckpt
+    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
+    for node in graph.nodes:
+        if node.name in ckpt_nodes:
+            assert hasattr(node, 'activation_checkpoint')
+
+    # assert checkpoint function will be generated
+    code = graph.python_code('self').src
+    assert 'checkpoint_0' in code and 'checkpoint_1' in code
+
+    # recompile and verify the outputs are consistent
+    gm = GraphModule(model, graph)
+    gm.recompile()
+    fx_out = gm(data)
+    assert torch.equal(non_fx_out, fx_out)
+
+
 if __name__ == '__main__':
     test_act_ckpt_codegen()
+    test_act_ckpt_python_code_torch11()