OpenDAS / ColossalAI · Commits · c41e59e5

Unverified commit c41e59e5, authored Jan 11, 2023 by Super Daniel, committed by GitHub on Jan 11, 2023.

[fx] allow native ckpt trace and codegen. (#2438)

Parent: 41429b9b

Showing 3 changed files with 37 additions and 23 deletions (+37 -23):

    colossalai/fx/graph_module.py            +10  -5
    colossalai/fx/tracer/_symbolic_trace.py   +2  -1
    colossalai/fx/tracer/experimental.py     +25 -17
colossalai/fx/graph_module.py  (+10 -5)

 import os
+import warnings
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Set, Type, Union
 import torch
 import torch.nn as nn
 from torch.nn.modules.module import _addindent
-from typing import Type, Dict, List, Any, Union, Optional, Set
-from pathlib import Path

 try:
-    from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
-    from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
+    from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
+    from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
+    from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
     COLOGM = True
 except:
-    from torch.fx.graph_module import GraphModule
     from torch.fx.graph import Graph
+    from torch.fx.graph_module import GraphModule
     COLOGM = False

 if COLOGM:
     ...

@@ -19,6 +23,7 @@ if COLOGM:
     class ColoGraphModule(GraphModule):

         def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
+            graph.set_codegen(ActivationCheckpointCodeGen())
             super().__init__(root, graph, class_name)

         def bind(self, ckpt_def, globals):
     ...
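The substantive change in this file is the `graph.set_codegen(ActivationCheckpointCodeGen())` call: every `ColoGraphModule` now regenerates its forward source through Colossal-AI's activation-checkpoint code generator instead of the default `torch.fx` codegen. A minimal usage sketch, assuming the public re-exports `colossalai.fx.ColoTracer` and `colossalai.fx.ColoGraphModule` (those import paths are not shown in this diff):

import torch
from colossalai.fx import ColoGraphModule, ColoTracer  # re-export paths assumed

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

# Trace the model; trace_act_ckpt is the flag added by this commit.
graph = ColoTracer(trace_act_ckpt=True).trace(model)

# ColoGraphModule.__init__ now installs ActivationCheckpointCodeGen on the graph,
# so the regenerated forward() keeps any recorded checkpoint regions.
gm = ColoGraphModule(model, graph, model.__class__.__name__)
print(gm.code)  # inspect the generated forward source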
colossalai/fx/tracer/_symbolic_trace.py  (+2 -1)

...
@@ -13,6 +13,7 @@ def symbolic_trace(
     root: Union[torch.nn.Module, Callable[..., Any]],
     concrete_args: Optional[Dict[str, Any]] = None,
     meta_args: Optional[Dict[str, Any]] = None,
+    trace_act_ckpt=False,
 ) -> ColoGraphModule:
     """
     Symbolic tracing API
...
@@ -49,6 +50,6 @@ def symbolic_trace(
     This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
     """
-    graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
+    graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=meta_args)
     name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
     return ColoGraphModule(root, graph, name)
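With the new flag, callers can ask the tracer to keep `torch.utils.checkpoint` regions intact instead of tracing straight through them. A hedged example of how the updated API might be called (the `colossalai.fx.symbolic_trace` re-export path and the `device='meta'` convention for `meta_args` are assumptions based on existing Colossal-AI usage, not shown in this diff):

import torch
from torch.utils.checkpoint import checkpoint

from colossalai.fx import symbolic_trace  # re-export path assumed


class Net(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 16)
        self.fc2 = torch.nn.Linear(16, 16)

    def forward(self, x):
        # Checkpointed sub-block; with trace_act_ckpt=True the tracer records
        # this region so the generated code can re-emit the checkpoint call.
        x = checkpoint(self.fc1, x)
        return self.fc2(x)


gm = symbolic_trace(Net(),
                    meta_args={'x': torch.rand(2, 16, device='meta')},
                    trace_act_ckpt=True)
print(gm.code)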
colossalai/fx/tracer/experimental.py  (+25 -17)

 import enum
 import functools
-import operator
 import inspect
+import operator
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
...
@@ -286,7 +286,6 @@ class ColoTracer(Tracer):
         self.graph.lint()
         return self.graph

     @contextmanager
     def trace_activation_checkpoint(self, enabled: bool):
         if enabled:
             ...
...
@@ -316,7 +315,6 @@ class ColoTracer(Tracer):
         # recover the checkpoint function upon exit
         torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func

     def _post_check(self, non_concrete_arg_names: Set[str]):
         # This is necessary because concrete args are added as input to the traced module since
         # https://github.com/pytorch/pytorch/pull/55888.
         ...
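The `trace_activation_checkpoint` context manager relies on the standard patch-and-restore idiom: while tracing is active it swaps `torch.utils.checkpoint.CheckpointFunction` for a tracer-aware replacement, and the `# recover the checkpoint function upon exit` lines above put the original back. A generic, self-contained sketch of that idiom (the replacement class here is a placeholder, not Colossal-AI's actual wrapper):

from contextlib import contextmanager

import torch.utils.checkpoint as ckpt


@contextmanager
def patch_checkpoint_function(enabled: bool):
    orig_ckpt_func = ckpt.CheckpointFunction
    if enabled:
        # Placeholder stand-in: a real tracer would record the checkpointed
        # region here before delegating to the original autograd.Function.
        class RecordingCheckpointFunction(orig_ckpt_func):
            pass

        ckpt.CheckpointFunction = RecordingCheckpointFunction
    try:
        yield
    finally:
        # recover the checkpoint function upon exit
        ckpt.CheckpointFunction = orig_ckpt_func

Patching the module attribute, rather than a local name, matters because `torch.utils.checkpoint.checkpoint` looks the class up at call time, so the swap takes effect for every checkpoint call made while the context is active.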
@@ -385,18 +383,23 @@ def symbolic_trace(
     root: Union[torch.nn.Module, Callable[..., Any]],
     concrete_args: Optional[Dict[str, Any]] = None,
     meta_args: Optional[Dict[str, Any]] = None,
+    trace_act_ckpt=False,
 ) -> ColoGraphModule:
     if is_compatible_with_meta():
         if meta_args is not None:
             root.to(default_device())
             wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
-            graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args))
+            graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args))
             root.cpu()
         else:
             graph = Tracer().trace(root, concrete_args=concrete_args)
     else:
         from .tracer import ColoTracer as OrigColoTracer
-        graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
+        graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=meta_args)
     name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
     return ColoGraphModule(root, graph, name)
...
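Before tracing, the meta branch wraps every tensor in `meta_args`; `tree_map` (presumably `torch.utils._pytree.tree_map`) applies `wrap_fn` to each leaf of the nested structure while leaving non-tensor leaves alone. A small stand-alone illustration of that wrapping step, using a plain `.to('meta')` in place of Colossal-AI's `MetaTensor`:

import torch
from torch.utils._pytree import tree_map

meta_args = {'x': torch.rand(2, 16), 'extras': [torch.ones(2, 16), 1.0]}

# Same shape as the wrap_fn in the diff: transform tensors, pass everything else through.
wrap_fn = lambda t: t.to('meta') if isinstance(t, torch.Tensor) else t

wrapped = tree_map(wrap_fn, meta_args)
print(wrapped['x'].device)    # meta
print(wrapped['extras'][1])   # 1.0, untouched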
@@ -471,11 +474,11 @@ def meta_prop_pass(gm: ColoGraphModule,
         node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
                                                node.kwargs)


 def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
     unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n

     if kind == 'placeholder':
         meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
     elif kind == 'get_attr':
         attr_itr = root
         atoms = target.split(".")
...
@@ -490,7 +493,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
             else:
                 if target not in _TensorPropertyMethod:
                     meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
                                                                    **tree_map(unwrap_fn, kwargs))
         elif kind == 'call_module':
             mod = root.get_submodule(target)
             meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
...
@@ -498,6 +501,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
         meta_out = None

     return meta_out


 def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
     if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
         meta_out = meta_args[target]
...
@@ -568,7 +572,7 @@ def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
     return meta_out


 def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
     result_graph = Graph()
     value_remap = {}
     unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
...
@@ -601,20 +605,24 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
                 if target == torch.nn.functional.linear:
                     if 'bias' in kwargs and kwargs['bias'] is not None:
                         function_to_substitute = func_to_func_dict[target]
                         handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
                                                                     function_to_substitute)
                 else:
                     function_to_substitute = func_to_func_dict[target]
                     handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
                                                                 function_to_substitute)
             elif bias_addition_function.has(target.__name__):
                 # use name for some builtin op like @ (matmul)
                 function_to_substitute = func_to_func_dict[target]
                 handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
                                                                      function_to_substitute)
         elif kind == "call_method":
             method = getattr(args_metas[0].__class__, target)
             if bias_addition_method.has(method):
                 function_to_substitute = method_to_func_dict[method]
                 handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
                                                           function_to_substitute)
         elif kind == "call_module":
             # if not hasattr(self, "orig_forward"):
...
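The branches above only select a registered handle; the rewriting itself happens in `handle.generate()`. Conceptually, the point of the pass is to split a biased call such as `torch.nn.functional.linear` into an unbiased op followed by an explicit add, so the bias addition becomes its own node in the graph. A framework-free sketch of that decomposition (this illustrates the idea only, not the `bias_addition_function` handle API, which is referenced here by name but defined elsewhere):

import torch
import torch.nn.functional as F


def linear_with_explicit_bias(x, weight, bias=None):
    # Decompose F.linear(x, w, b) into F.linear(x, w) + b so the addition
    # shows up as a separate node in a traced graph.
    out = F.linear(x, weight)
    if bias is not None:
        out = out + bias
    return out


x = torch.rand(4, 8)
w = torch.rand(16, 8)
b = torch.rand(16)
assert torch.allclose(linear_with_explicit_bias(x, w, b), F.linear(x, w, b))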
@@ -623,20 +631,20 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
             mod_type = type(mod)
             if bias_addition_module.has(mod_type) and mod.bias is not None:
                 function_to_substitute = module_to_func_dict[mod_type]
                 handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
                                                             function_to_substitute)

         if handle is not None:
             handle.generate()
             for node_inserted in tracer.graph.nodes:
                 value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n: value_remap[n])
                 last_node = value_remap[node_inserted]
             value_remap[orig_node] = last_node
         else:
             value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n: value_remap[n])

         del tracer

     gm.graph = result_graph
     gm.recompile()

     meta_prop_pass(gm, root_model, meta_args)