Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
441d584e
Unverified
Commit
441d584e
authored
Nov 08, 2022
by
Super Daniel
Committed by
GitHub
Nov 08, 2022
Browse files
[fx] add a symbolic_trace api. (#1812)
* [fx] add a symbolic_trace api. * [fx] fix import errors.
parent
350ccc04
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
90 additions
and
73 deletions
+90
-73
colossalai/fx/__init__.py
colossalai/fx/__init__.py
+4
-4
colossalai/fx/tracer/__init__.py
colossalai/fx/tracer/__init__.py
+1
-0
colossalai/fx/tracer/_symbolic_trace.py
colossalai/fx/tracer/_symbolic_trace.py
+58
-0
tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py
tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py
+2
-7
tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
+1
-1
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
+1
-1
tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
+5
-14
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
+1
-1
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
+1
-1
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
+1
-1
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
+5
-10
tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
..._fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
+2
-6
tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
...t_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
+2
-11
tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
...est_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
+3
-9
tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
...t_tracer/test_torchvision_model/test_torchvision_model.py
+3
-7
No files found.
colossalai/fx/__init__.py
View file @
441d584e
from
._compatibility
import
compatibility
,
is_compatible_with_meta
from
._compatibility
import
compatibility
,
is_compatible_with_meta
from
.graph_module
import
ColoGraphModule
from
.graph_module
import
ColoGraphModule
from
.passes
import
MetaInfoProp
from
.passes
import
MetaInfoProp
from
.tracer
import
ColoTracer
,
meta_trace
from
.tracer
import
ColoTracer
,
meta_trace
,
symbolic_trace
colossalai/fx/tracer/__init__.py
View file @
441d584e
from
colossalai.fx.tracer.meta_patch.patched_function.python_ops
import
operator_getitem
from
colossalai.fx.tracer.meta_patch.patched_function.python_ops
import
operator_getitem
from
._meta_trace
import
meta_trace
from
._meta_trace
import
meta_trace
from
._symbolic_trace
import
symbolic_trace
from
.tracer
import
ColoTracer
from
.tracer
import
ColoTracer
colossalai/fx/tracer/_symbolic_trace.py
0 → 100644
View file @
441d584e
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Union
import
torch
from
colossalai.fx
import
ColoGraphModule
from
colossalai.fx._compatibility
import
compatibility
from
.tracer
import
ColoTracer
@
compatibility
(
is_backward_compatible
=
True
)
def
symbolic_trace
(
root
:
Union
[
torch
.
nn
.
Module
,
Callable
[...,
Any
]],
concrete_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
meta_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
ColoGraphModule
:
"""
Symbolic tracing API
Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule``
constructed by recording operations seen while tracing through ``root``.
With ``meta_args`` and ``concrete_args``, we can trace the model that are untraceable subject to control flow.
If specified using ``meta_args`` only, the tracing can be done ahead of time.
Note that both ``meta_args`` and ``concrete_args`` are kwargs, which contains the key of the argument's names
and the value of the argument's values.
Uses:
>>> model = ...
# if this works
>>> gm = symbolic_trace(model)
# else try this
>>> gm = symbolic_trace(model, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})
# else try this
>>> gm = symbolic_trace(model, concrete_args={'x': torch.rand(1, 3, 224, 224)})
Args:
root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted
into a Graph representation.
concrete_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized. Defaults to None.
meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for ``ColoTracer``.
Defaults to None.
Returns:
ColoGraphModule: A ``ColoGraphModule`` created from the recorded operations from ``root``.
Warnings:
This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
"""
tracer
=
ColoTracer
()
graph
=
tracer
.
trace
(
root
,
concrete_args
,
meta_args
)
name
=
(
root
.
__class__
.
__name__
if
isinstance
(
root
,
torch
.
nn
.
Module
)
else
root
.
__name__
)
return
ColoGraphModule
(
tracer
.
root
,
graph
,
name
)
tests/test_fx/test_tracer/test_hf_model/utils.py
→
tests/test_fx/test_tracer/test_hf_model/
hf_tracer_
utils.py
View file @
441d584e
...
@@ -3,24 +3,19 @@ from numpy import isin
...
@@ -3,24 +3,19 @@ from numpy import isin
from
torch.fx
import
GraphModule
from
torch.fx
import
GraphModule
from
torch.utils._pytree
import
tree_flatten
from
torch.utils._pytree
import
tree_flatten
from
colossalai.fx
import
ColoT
race
r
from
colossalai.fx
import
symbolic_t
race
def
trace_model_and_compare_output
(
model
,
data_gen
):
def
trace_model_and_compare_output
(
model
,
data_gen
):
# must turn on eval mode to ensure the output is consistent
# must turn on eval mode to ensure the output is consistent
model
.
eval
()
model
.
eval
()
# make sure that the model is traceable
tracer
=
ColoTracer
()
try
:
try
:
kwargs
=
data_gen
()
kwargs
=
data_gen
()
meta_args
=
{
k
:
v
.
to
(
'meta'
)
for
k
,
v
in
kwargs
.
items
()}
meta_args
=
{
k
:
v
.
to
(
'meta'
)
for
k
,
v
in
kwargs
.
items
()}
g
raph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
meta_args
)
g
m
=
symbolic_
trace
(
model
,
meta_args
=
meta_args
)
except
Exception
as
e
:
except
Exception
as
e
:
raise
RuntimeError
(
f
"Failed to trace
{
model
.
__class__
.
__name__
}
, error:
{
e
}
"
)
raise
RuntimeError
(
f
"Failed to trace
{
model
.
__class__
.
__name__
}
, error:
{
e
}
"
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
# run forward
# run forward
inputs
=
data_gen
()
inputs
=
data_gen
()
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
View file @
441d584e
import
pytest
import
pytest
import
torch
import
torch
import
transformers
import
transformers
from
utils
import
trace_model_and_compare_output
from
hf_tracer_
utils
import
trace_model_and_compare_output
BATCH_SIZE
=
2
BATCH_SIZE
=
2
SEQ_LENGTH
=
16
SEQ_LENGTH
=
16
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
View file @
441d584e
import
pytest
import
pytest
import
torch
import
torch
import
transformers
import
transformers
from
utils
import
trace_model_and_compare_output
from
hf_tracer_
utils
import
trace_model_and_compare_output
BATCH_SIZE
=
2
BATCH_SIZE
=
2
SEQ_LENGTH
=
16
SEQ_LENGTH
=
16
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
View file @
441d584e
import
pytest
import
pytest
import
torch
import
torch
from
torch.fx
import
GraphModule
from
utils
import
trace_model_and_compare_output
import
transformers
import
transformers
from
colossalai.fx
import
ColoTracer
from
hf_tracer_utils
import
trace_model_and_compare_output
from
colossalai.fx
import
symbolic_trace
try
:
try
:
import
diffusers
import
diffusers
...
@@ -32,11 +31,7 @@ def test_vae():
...
@@ -32,11 +31,7 @@ def test_vae():
model
=
model_cls
()
model
=
model_cls
()
sample
=
torch
.
zeros
(
LATENTS_SHAPE
)
sample
=
torch
.
zeros
(
LATENTS_SHAPE
)
tracer
=
ColoTracer
()
gm
=
symbolic_trace
(
model
)
graph
=
tracer
.
trace
(
root
=
model
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
model
.
eval
()
model
.
eval
()
gm
.
eval
()
gm
.
eval
()
...
@@ -98,11 +93,7 @@ def test_unet():
...
@@ -98,11 +93,7 @@ def test_unet():
model
=
model_cls
()
model
=
model_cls
()
sample
=
torch
.
zeros
(
LATENTS_SHAPE
)
sample
=
torch
.
zeros
(
LATENTS_SHAPE
)
tracer
=
ColoTracer
()
gm
=
symbolic_trace
(
model
)
graph
=
tracer
.
trace
(
root
=
model
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
model
.
eval
()
model
.
eval
()
gm
.
eval
()
gm
.
eval
()
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
View file @
441d584e
import
pytest
import
pytest
import
torch
import
torch
import
transformers
import
transformers
from
utils
import
trace_model_and_compare_output
from
hf_tracer_
utils
import
trace_model_and_compare_output
BATCH_SIZE
=
1
BATCH_SIZE
=
1
SEQ_LENGTH
=
16
SEQ_LENGTH
=
16
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
View file @
441d584e
import
pytest
import
pytest
import
torch
import
torch
import
transformers
import
transformers
from
utils
import
trace_model_and_compare_output
from
hf_tracer_
utils
import
trace_model_and_compare_output
BATCH_SIZE
=
1
BATCH_SIZE
=
1
SEQ_LENGTH
=
16
SEQ_LENGTH
=
16
...
...
tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
View file @
441d584e
import
pytest
import
pytest
import
torch
import
torch
import
transformers
import
transformers
from
utils
import
trace_model_and_compare_output
from
hf_tracer_
utils
import
trace_model_and_compare_output
BATCH_SIZE
=
1
BATCH_SIZE
=
1
SEQ_LENGTH
=
16
SEQ_LENGTH
=
16
...
...
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
View file @
441d584e
import
pytest
import
pytest
import
timm.models
as
tm
import
timm.models
as
tm
import
torch
import
torch
from
torch.fx
import
GraphModule
from
colossalai.fx
import
ColoT
race
r
from
colossalai.fx
import
symbolic_t
race
def
trace_and_compare
(
model_cls
,
tracer
,
data
,
meta_args
=
None
):
def
trace_and_compare
(
model_cls
,
data
,
meta_args
=
None
):
# trace
# trace
model
=
model_cls
()
model
=
model_cls
()
...
@@ -15,9 +14,7 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
...
@@ -15,9 +14,7 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
# without this statement, the torch.nn.functional.batch_norm will always be in training mode
# without this statement, the torch.nn.functional.batch_norm will always be in training mode
model
.
eval
()
model
.
eval
()
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
meta_args
)
gm
=
symbolic_trace
(
model
,
meta_args
=
meta_args
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
# run forward
# run forward
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
@@ -49,11 +46,10 @@ def test_timm_models_without_control_flow():
...
@@ -49,11 +46,10 @@ def test_timm_models_without_control_flow():
tm
.
deit_base_distilled_patch16_224
,
tm
.
deit_base_distilled_patch16_224
,
]
]
tracer
=
ColoTracer
()
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
for
model_cls
in
MODEL_LIST
:
for
model_cls
in
MODEL_LIST
:
trace_and_compare
(
model_cls
,
tracer
,
data
)
trace_and_compare
(
model_cls
,
data
)
def
test_timm_models_with_control_flow
():
def
test_timm_models_with_control_flow
():
...
@@ -64,13 +60,12 @@ def test_timm_models_with_control_flow():
...
@@ -64,13 +60,12 @@ def test_timm_models_with_control_flow():
tm
.
swin_transformer
.
swin_base_patch4_window7_224
tm
.
swin_transformer
.
swin_base_patch4_window7_224
]
]
tracer
=
ColoTracer
()
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
meta_args
=
{
'x'
:
data
.
to
(
'meta'
)}
meta_args
=
{
'x'
:
data
.
to
(
'meta'
)}
for
model_cls
in
MODEL_LIST_WITH_CONTROL_FLOW
:
for
model_cls
in
MODEL_LIST_WITH_CONTROL_FLOW
:
trace_and_compare
(
model_cls
,
tracer
,
data
,
meta_args
)
trace_and_compare
(
model_cls
,
data
,
meta_args
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
View file @
441d584e
import
torch
import
torch
from
torch.fx
import
GraphModule
,
Tracer
from
colossalai.fx
import
ColoT
race
r
from
colossalai.fx
import
symbolic_t
race
def
trace_and_compare
(
model
,
data_gen
,
need_meta
=
False
,
need_concrete
=
False
,
kwargs_transform
=
False
):
def
trace_and_compare
(
model
,
data_gen
,
need_meta
=
False
,
need_concrete
=
False
,
kwargs_transform
=
False
):
data
=
data_gen
()
data
=
data_gen
()
concrete_args
=
data
if
need_concrete
else
{}
concrete_args
=
data
if
need_concrete
else
{}
meta_args
=
{
k
:
v
.
to
(
'meta'
)
for
k
,
v
in
data
.
items
()}
if
need_meta
else
{}
meta_args
=
{
k
:
v
.
to
(
'meta'
)
for
k
,
v
in
data
.
items
()}
if
need_meta
else
{}
tracer
=
ColoTracer
()
model
.
eval
()
model
.
eval
()
graph
=
tracer
.
trace
(
root
=
model
,
concrete_args
=
concrete_args
,
meta_args
=
meta_args
)
gm
=
symbolic_trace
(
model
,
concrete_args
=
concrete_args
,
meta_args
=
meta_args
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
with
torch
.
no_grad
():
with
torch
.
no_grad
():
non_fx_out
=
model
(
**
data
)
non_fx_out
=
model
(
**
data
)
...
...
tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
View file @
441d584e
import
pytest
import
pytest
import
torch
import
torch
from
colossalai.fx.tracer
import
meta_patch
from
colossalai.fx
import
symbolic_trace
from
colossalai.fx.tracer.meta_patch.patched_function
import
python_ops
from
colossalai.fx.tracer.tracer
import
ColoTracer
try
:
try
:
from
torchrec.models
import
deepfm
from
torchrec.models
import
deepfm
...
@@ -14,8 +12,6 @@ try:
...
@@ -14,8 +12,6 @@ try:
except
ImportError
:
except
ImportError
:
NOT_TORCHREC
=
True
NOT_TORCHREC
=
True
from
torch.fx
import
GraphModule
BATCH
=
2
BATCH
=
2
SHAPE
=
10
SHAPE
=
10
...
@@ -43,9 +39,6 @@ def test_torchrec_deepfm_models():
...
@@ -43,9 +39,6 @@ def test_torchrec_deepfm_models():
# Dense Features
# Dense Features
features
=
torch
.
rand
((
BATCH
,
SHAPE
))
features
=
torch
.
rand
((
BATCH
,
SHAPE
))
# Tracer
tracer
=
ColoTracer
()
for
model_cls
in
MODEL_LIST
:
for
model_cls
in
MODEL_LIST
:
# Initializing model
# Initializing model
if
model_cls
==
deepfm
.
DenseArch
:
if
model_cls
==
deepfm
.
DenseArch
:
...
@@ -60,9 +53,7 @@ def test_torchrec_deepfm_models():
...
@@ -60,9 +53,7 @@ def test_torchrec_deepfm_models():
model
=
model_cls
(
ebc
)
model
=
model_cls
(
ebc
)
# Setup GraphModule
# Setup GraphModule
graph
=
tracer
.
trace
(
model
)
gm
=
symbolic_trace
(
model
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
model
.
eval
()
model
.
eval
()
gm
.
eval
()
gm
.
eval
()
...
...
tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
View file @
441d584e
import
torch
import
torch
from
colossalai.fx
.tracer.tracer
import
ColoT
race
r
from
colossalai.fx
import
symbolic_t
race
try
:
try
:
from
torchrec.models
import
dlrm
from
torchrec.models
import
dlrm
...
@@ -12,7 +12,6 @@ except ImportError:
...
@@ -12,7 +12,6 @@ except ImportError:
NOT_TORCHREC
=
True
NOT_TORCHREC
=
True
import
pytest
import
pytest
from
torch.fx
import
GraphModule
BATCH
=
2
BATCH
=
2
SHAPE
=
10
SHAPE
=
10
...
@@ -51,8 +50,6 @@ def test_torchrec_dlrm_models():
...
@@ -51,8 +50,6 @@ def test_torchrec_dlrm_models():
# Sparse Features
# Sparse Features
sparse_features
=
torch
.
rand
((
BATCH
,
len
(
keys
),
SHAPE
))
sparse_features
=
torch
.
rand
((
BATCH
,
len
(
keys
),
SHAPE
))
# Tracer
tracer
=
ColoTracer
()
for
model_cls
in
MODEL_LIST
:
for
model_cls
in
MODEL_LIST
:
# Initializing model
# Initializing model
...
@@ -77,12 +74,9 @@ def test_torchrec_dlrm_models():
...
@@ -77,12 +74,9 @@ def test_torchrec_dlrm_models():
# Setup GraphModule
# Setup GraphModule
if
model_cls
==
dlrm
.
InteractionV2Arch
:
if
model_cls
==
dlrm
.
InteractionV2Arch
:
concrete_args
=
{
"dense_features"
:
dense_features
,
"sparse_features"
:
sparse_features
}
concrete_args
=
{
"dense_features"
:
dense_features
,
"sparse_features"
:
sparse_features
}
g
raph
=
tracer
.
trace
(
model
,
concrete_args
=
concrete_args
)
g
m
=
symbolic_
trace
(
model
,
concrete_args
=
concrete_args
)
else
:
else
:
graph
=
tracer
.
trace
(
model
)
gm
=
symbolic_trace
(
model
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
model
.
eval
()
model
.
eval
()
gm
.
eval
()
gm
.
eval
()
...
...
tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
View file @
441d584e
...
@@ -2,8 +2,8 @@ import torch
...
@@ -2,8 +2,8 @@ import torch
import
torchvision
import
torchvision
import
torchvision.models
as
tm
import
torchvision.models
as
tm
from
packaging
import
version
from
packaging
import
version
from
colossalai.fx
import
ColoTracer
from
torch
.fx
import
GraphModul
e
from
colossalai
.fx
import
symbolic_trac
e
def
test_torchvision_models
():
def
test_torchvision_models
():
...
@@ -20,7 +20,6 @@ def test_torchvision_models():
...
@@ -20,7 +20,6 @@ def test_torchvision_models():
torch
.
backends
.
cudnn
.
deterministic
=
True
torch
.
backends
.
cudnn
.
deterministic
=
True
tracer
=
ColoTracer
()
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
for
model_cls
in
MODEL_LIST
:
for
model_cls
in
MODEL_LIST
:
...
@@ -30,10 +29,7 @@ def test_torchvision_models():
...
@@ -30,10 +29,7 @@ def test_torchvision_models():
else
:
else
:
model
=
model_cls
()
model
=
model_cls
()
graph
=
tracer
.
trace
(
root
=
model
)
gm
=
symbolic_trace
(
model
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
model
.
eval
()
model
.
eval
()
gm
.
eval
()
gm
.
eval
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment