Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
441d584e
Unverified
Commit
441d584e
authored
Nov 08, 2022
by
Super Daniel
Committed by
GitHub
Nov 08, 2022
Browse files
[fx] add a symbolic_trace api. (#1812)
* [fx] add a symbolic_trace api. * [fx] fix import errors.
parent
350ccc04
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
90 additions
and
73 deletions
+90
-73
colossalai/fx/__init__.py
colossalai/fx/__init__.py
+4
-4
colossalai/fx/tracer/__init__.py
colossalai/fx/tracer/__init__.py
+1
-0
colossalai/fx/tracer/_symbolic_trace.py
colossalai/fx/tracer/_symbolic_trace.py
+58
-0
tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py
tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py
+2
-7
tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
+1
-1
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
+1
-1
tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
+5
-14
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
+1
-1
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
+1
-1
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
+1
-1
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
+5
-10
tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
..._fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
+2
-6
tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
...t_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
+2
-11
tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
...est_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
+3
-9
tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
...t_tracer/test_torchvision_model/test_torchvision_model.py
+3
-7
No files found.
colossalai/fx/__init__.py
View file @
441d584e
from
._compatibility
import
compatibility
,
is_compatible_with_meta
from
.graph_module
import
ColoGraphModule
from
.passes
import
MetaInfoProp
from
.tracer
import
ColoTracer
,
meta_trace
from
._compatibility
import
compatibility
,
is_compatible_with_meta
from
.graph_module
import
ColoGraphModule
from
.passes
import
MetaInfoProp
from
.tracer
import
ColoTracer
,
meta_trace
,
symbolic_trace
colossalai/fx/tracer/__init__.py
View file @
441d584e
from
colossalai.fx.tracer.meta_patch.patched_function.python_ops
import
operator_getitem
from
._meta_trace
import
meta_trace
from
._symbolic_trace
import
symbolic_trace
from
.tracer
import
ColoTracer
colossalai/fx/tracer/_symbolic_trace.py
0 → 100644
View file @
441d584e
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Union
import
torch
from
colossalai.fx
import
ColoGraphModule
from
colossalai.fx._compatibility
import
compatibility
from
.tracer
import
ColoTracer
@
compatibility
(
is_backward_compatible
=
True
)
def
symbolic_trace
(
root
:
Union
[
torch
.
nn
.
Module
,
Callable
[...,
Any
]],
concrete_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
meta_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
ColoGraphModule
:
"""
Symbolic tracing API
Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule``
constructed by recording operations seen while tracing through ``root``.
With ``meta_args`` and ``concrete_args``, we can trace the model that are untraceable subject to control flow.
If specified using ``meta_args`` only, the tracing can be done ahead of time.
Note that both ``meta_args`` and ``concrete_args`` are kwargs, which contains the key of the argument's names
and the value of the argument's values.
Uses:
>>> model = ...
# if this works
>>> gm = symbolic_trace(model)
# else try this
>>> gm = symbolic_trace(model, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})
# else try this
>>> gm = symbolic_trace(model, concrete_args={'x': torch.rand(1, 3, 224, 224)})
Args:
root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted
into a Graph representation.
concrete_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized. Defaults to None.
meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``.
Defaults to None.
Returns:
ColoGraphModule: A ``ColoGraphModule`` created from the recorded operations from ``root``.
Warnings:
This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
"""
tracer
=
ColoTracer
()
graph
=
tracer
.
trace
(
root
,
concrete_args
,
meta_args
)
name
=
(
root
.
__class__
.
__name__
if
isinstance
(
root
,
torch
.
nn
.
Module
)
else
root
.
__name__
)
return
ColoGraphModule
(
tracer
.
root
,
graph
,
name
)
tests/test_fx/test_tracer/test_hf_model/utils.py
→
tests/test_fx/test_tracer/test_hf_model/
hf_tracer_
utils.py
View file @
441d584e
...
...
@@ -3,24 +3,19 @@ from numpy import isin
from
torch.fx
import
GraphModule
from
torch.utils._pytree
import
tree_flatten
from
colossalai.fx
import
ColoT
race
r
from
colossalai.fx
import
symbolic_t
race
def
trace_model_and_compare_output
(
model
,
data_gen
):
# must turn on eval mode to ensure the output is consistent
model
.
eval
()
# make sure that the model is traceable
tracer
=
ColoTracer
()
try
:
kwargs
=
data_gen
()
meta_args
=
{
k
:
v
.
to
(
'meta'
)
for
k
,
v
in
kwargs
.
items
()}
g
raph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
meta_args
)
g
m
=
symbolic_
trace
(
model
,
meta_args
=
meta_args
)
except
Exception
as
e
:
raise
RuntimeError
(
f
"Failed to trace
{
model
.
__class__
.
__name__
}
, error:
{
e
}
"
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
# run forward
inputs
=
data_gen
()
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
View file @
441d584e
import
pytest
import
torch
import
transformers
from
utils
import
trace_model_and_compare_output
from
hf_tracer_
utils
import
trace_model_and_compare_output
BATCH_SIZE
=
2
SEQ_LENGTH
=
16
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
View file @
441d584e
import
pytest
import
torch
import
transformers
from
utils
import
trace_model_and_compare_output
from
hf_tracer_
utils
import
trace_model_and_compare_output
BATCH_SIZE
=
2
SEQ_LENGTH
=
16
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
View file @
441d584e
import
pytest
import
torch
from
torch.fx
import
GraphModule
from
utils
import
trace_model_and_compare_output
import
transformers
from
colossalai.fx
import
ColoTracer
from
hf_tracer_utils
import
trace_model_and_compare_output
from
colossalai.fx
import
symbolic_trace
try
:
import
diffusers
...
...
@@ -32,11 +31,7 @@ def test_vae():
model
=
model_cls
()
sample
=
torch
.
zeros
(
LATENTS_SHAPE
)
tracer
=
ColoTracer
()
graph
=
tracer
.
trace
(
root
=
model
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
gm
=
symbolic_trace
(
model
)
model
.
eval
()
gm
.
eval
()
...
...
@@ -98,11 +93,7 @@ def test_unet():
model
=
model_cls
()
sample
=
torch
.
zeros
(
LATENTS_SHAPE
)
tracer
=
ColoTracer
()
graph
=
tracer
.
trace
(
root
=
model
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
gm
=
symbolic_trace
(
model
)
model
.
eval
()
gm
.
eval
()
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
View file @
441d584e
import
pytest
import
torch
import
transformers
from
utils
import
trace_model_and_compare_output
from
hf_tracer_
utils
import
trace_model_and_compare_output
BATCH_SIZE
=
1
SEQ_LENGTH
=
16
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
View file @
441d584e
import
pytest
import
torch
import
transformers
from
utils
import
trace_model_and_compare_output
from
hf_tracer_
utils
import
trace_model_and_compare_output
BATCH_SIZE
=
1
SEQ_LENGTH
=
16
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
View file @
441d584e
import
pytest
import
torch
import
transformers
from
utils
import
trace_model_and_compare_output
from
hf_tracer_
utils
import
trace_model_and_compare_output
BATCH_SIZE
=
1
SEQ_LENGTH
=
16
...
...
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
View file @
441d584e
import
pytest
import
timm.models
as
tm
import
torch
from
torch.fx
import
GraphModule
from
colossalai.fx
import
ColoT
race
r
from
colossalai.fx
import
symbolic_t
race
def
trace_and_compare
(
model_cls
,
tracer
,
data
,
meta_args
=
None
):
def
trace_and_compare
(
model_cls
,
data
,
meta_args
=
None
):
# trace
model
=
model_cls
()
...
...
@@ -15,9 +14,7 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
# without this statement, the torch.nn.functional.batch_norm will always be in training mode
model
.
eval
()
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
meta_args
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
gm
=
symbolic_trace
(
model
,
meta_args
=
meta_args
)
# run forward
with
torch
.
no_grad
():
...
...
@@ -49,11 +46,10 @@ def test_timm_models_without_control_flow():
tm
.
deit_base_distilled_patch16_224
,
]
tracer
=
ColoTracer
()
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
for
model_cls
in
MODEL_LIST
:
trace_and_compare
(
model_cls
,
tracer
,
data
)
trace_and_compare
(
model_cls
,
data
)
def
test_timm_models_with_control_flow
():
...
...
@@ -64,13 +60,12 @@ def test_timm_models_with_control_flow():
tm
.
swin_transformer
.
swin_base_patch4_window7_224
]
tracer
=
ColoTracer
()
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
meta_args
=
{
'x'
:
data
.
to
(
'meta'
)}
for
model_cls
in
MODEL_LIST_WITH_CONTROL_FLOW
:
trace_and_compare
(
model_cls
,
tracer
,
data
,
meta_args
)
trace_and_compare
(
model_cls
,
data
,
meta_args
)
if
__name__
==
'__main__'
:
...
...
tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
View file @
441d584e
import
torch
from
torch.fx
import
GraphModule
,
Tracer
from
colossalai.fx
import
ColoT
race
r
from
colossalai.fx
import
symbolic_t
race
def
trace_and_compare
(
model
,
data_gen
,
need_meta
=
False
,
need_concrete
=
False
,
kwargs_transform
=
False
):
data
=
data_gen
()
concrete_args
=
data
if
need_concrete
else
{}
meta_args
=
{
k
:
v
.
to
(
'meta'
)
for
k
,
v
in
data
.
items
()}
if
need_meta
else
{}
tracer
=
ColoTracer
()
model
.
eval
()
graph
=
tracer
.
trace
(
root
=
model
,
concrete_args
=
concrete_args
,
meta_args
=
meta_args
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
gm
=
symbolic_trace
(
model
,
concrete_args
=
concrete_args
,
meta_args
=
meta_args
)
with
torch
.
no_grad
():
non_fx_out
=
model
(
**
data
)
...
...
tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
View file @
441d584e
import
pytest
import
torch
from
colossalai.fx.tracer
import
meta_patch
from
colossalai.fx.tracer.meta_patch.patched_function
import
python_ops
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.fx
import
symbolic_trace
try
:
from
torchrec.models
import
deepfm
...
...
@@ -14,8 +12,6 @@ try:
except
ImportError
:
NOT_TORCHREC
=
True
from
torch.fx
import
GraphModule
BATCH
=
2
SHAPE
=
10
...
...
@@ -43,9 +39,6 @@ def test_torchrec_deepfm_models():
# Dense Features
features
=
torch
.
rand
((
BATCH
,
SHAPE
))
# Tracer
tracer
=
ColoTracer
()
for
model_cls
in
MODEL_LIST
:
# Initializing model
if
model_cls
==
deepfm
.
DenseArch
:
...
...
@@ -60,9 +53,7 @@ def test_torchrec_deepfm_models():
model
=
model_cls
(
ebc
)
# Setup GraphModule
graph
=
tracer
.
trace
(
model
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
gm
=
symbolic_trace
(
model
)
model
.
eval
()
gm
.
eval
()
...
...
tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
View file @
441d584e
import
torch
from
colossalai.fx
.tracer.tracer
import
ColoT
race
r
from
colossalai.fx
import
symbolic_t
race
try
:
from
torchrec.models
import
dlrm
...
...
@@ -12,7 +12,6 @@ except ImportError:
NOT_TORCHREC
=
True
import
pytest
from
torch.fx
import
GraphModule
BATCH
=
2
SHAPE
=
10
...
...
@@ -51,8 +50,6 @@ def test_torchrec_dlrm_models():
# Sparse Features
sparse_features
=
torch
.
rand
((
BATCH
,
len
(
keys
),
SHAPE
))
# Tracer
tracer
=
ColoTracer
()
for
model_cls
in
MODEL_LIST
:
# Initializing model
...
...
@@ -77,12 +74,9 @@ def test_torchrec_dlrm_models():
# Setup GraphModule
if
model_cls
==
dlrm
.
InteractionV2Arch
:
concrete_args
=
{
"dense_features"
:
dense_features
,
"sparse_features"
:
sparse_features
}
g
raph
=
tracer
.
trace
(
model
,
concrete_args
=
concrete_args
)
g
m
=
symbolic_
trace
(
model
,
concrete_args
=
concrete_args
)
else
:
graph
=
tracer
.
trace
(
model
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
gm
=
symbolic_trace
(
model
)
model
.
eval
()
gm
.
eval
()
...
...
tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
View file @
441d584e
...
...
@@ -2,8 +2,8 @@ import torch
import
torchvision
import
torchvision.models
as
tm
from
packaging
import
version
from
colossalai.fx
import
ColoTracer
from
torch
.fx
import
GraphModul
e
from
colossalai
.fx
import
symbolic_trac
e
def
test_torchvision_models
():
...
...
@@ -20,7 +20,6 @@ def test_torchvision_models():
torch
.
backends
.
cudnn
.
deterministic
=
True
tracer
=
ColoTracer
()
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
for
model_cls
in
MODEL_LIST
:
...
...
@@ -30,10 +29,7 @@ def test_torchvision_models():
else
:
model
=
model_cls
()
graph
=
tracer
.
trace
(
root
=
model
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
gm
=
symbolic_trace
(
model
)
model
.
eval
()
gm
.
eval
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment