Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
75abc75c
Unverified
Commit
75abc75c
authored
Jul 18, 2022
by
Frank Lee
Committed by
GitHub
Jul 18, 2022
Browse files
[fx] fixed compatiblity issue with torch 1.10 (#1331)
parent
069d6fdc
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
32 additions
and
28 deletions
+32
-28
colossalai/fx/passes/split_module.py
colossalai/fx/passes/split_module.py
+8
-4
colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
...salai/fx/tracer/meta_patch/patched_function/arithmetic.py
+1
-0
colossalai/fx/tracer/tracer.py
colossalai/fx/tracer/tracer.py
+3
-0
tests/test_fx/test_pipeline/test_timm_model/test_timm.py
tests/test_fx/test_pipeline/test_timm_model/test_timm.py
+1
-5
tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py
...est_fx/test_pipeline/test_torchvision/test_torchvision.py
+7
-7
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
+1
-5
tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
...t_tracer/test_torchvision_model/test_torchvision_model.py
+11
-7
No files found.
colossalai/fx/passes/split_module.py
View file @
75abc75c
...
...
@@ -2,6 +2,7 @@ import torch
from
torch.fx.graph_module
import
GraphModule
from
typing
import
Callable
,
List
,
Dict
,
Any
,
Optional
from
torch.fx._compatibility
import
compatibility
from
packaging
import
version
import
inspect
...
...
@@ -233,10 +234,13 @@ def split_module(
base_mod_attrs
:
Dict
[
str
,
torch
.
fx
.
graph_module
.
GraphModule
]
=
{}
for
node
in
m
.
graph
.
nodes
:
if
node
.
op
==
'placeholder'
:
default_value
=
node
.
args
[
0
]
if
len
(
node
.
args
)
>
0
else
inspect
.
Signature
.
empty
base_mod_env
[
node
.
name
]
=
base_mod_graph
.
placeholder
(
node
.
name
,
type_expr
=
node
.
type
,
default_value
=
default_value
)
if
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'1.11.0'
):
base_mod_env
[
node
.
name
]
=
base_mod_graph
.
placeholder
(
node
.
name
,
type_expr
=
node
.
type
)
else
:
default_value
=
node
.
args
[
0
]
if
len
(
node
.
args
)
>
0
else
inspect
.
Signature
.
empty
base_mod_env
[
node
.
name
]
=
base_mod_graph
.
placeholder
(
node
.
name
,
type_expr
=
node
.
type
,
default_value
=
default_value
)
base_mod_env
[
node
.
name
].
meta
=
node
.
meta
.
copy
()
# Do some things iterating over the partitions in topological order again:
...
...
colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
View file @
75abc75c
...
...
@@ -3,6 +3,7 @@ from ..registry import meta_patched_function
@
meta_patched_function
.
register
(
torch
.
matmul
)
@
meta_patched_function
.
register
(
'matmul'
)
# for built-in op @
def
torch_matmul
(
input
,
other
,
*
,
out
=
None
):
# copied from huggingface.utils.fx
d1
=
input
.
dim
()
...
...
colossalai/fx/tracer/tracer.py
View file @
75abc75c
...
...
@@ -96,6 +96,9 @@ class ColoTracer(Tracer):
# fetch patched function
if
meta_patched_function
.
has
(
target
):
meta_target
=
meta_patched_function
.
get
(
target
)
elif
meta_patched_function
.
has
(
target
.
__name__
):
# use name for some builtin op like @ (matmul)
meta_target
=
meta_patched_function
.
get
(
target
.
__name__
)
else
:
meta_target
=
target
...
...
tests/test_fx/test_pipeline/test_timm_model/test_timm.py
View file @
75abc75c
import
torch
import
pytest
try
:
import
timm.models
as
tm
except
:
pass
import
timm.models
as
tm
from
timm_utils
import
split_model_and_compare_output
...
...
tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py
View file @
75abc75c
import
torch
try
:
import
torchvision.models
as
tm
except
:
pass
import
torchvision
import
torchvision.models
as
tm
from
colossalai.fx
import
ColoTracer
from
colossalai.fx.passes.adding_split_node_pass
import
split_with_split_nodes_pass
,
balanced_split_pass
from
torch.fx
import
GraphModule
from
packaging
import
version
import
random
import
numpy
as
np
import
inspect
import
pytest
MANUAL_SEED
=
0
random
.
seed
(
MANUAL_SEED
)
...
...
@@ -22,9 +19,12 @@ torch.backends.cudnn.deterministic = True
def
test_torchvision_models
():
MODEL_LIST
=
[
tm
.
vgg11
,
tm
.
resnet18
,
tm
.
densenet121
,
tm
.
mobilenet_v3_small
,
tm
.
resnext50_32x4d
,
tm
.
wide_resnet50_2
,
tm
.
regnet_x_16gf
,
tm
.
vit_b_16
,
tm
.
convnext_small
,
tm
.
efficientnet_b0
,
tm
.
mnasnet0_5
tm
.
regnet_x_16gf
,
tm
.
efficientnet_b0
,
tm
.
mnasnet0_5
]
if
version
.
parse
(
torchvision
.
__version__
)
>=
version
.
parse
(
'0.12.0'
):
MODEL_LIST
.
extend
([
tm
.
vit_b_16
,
tm
.
convnext_small
])
tracer
=
ColoTracer
()
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
...
...
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
View file @
75abc75c
import
torch
import
pytest
try
:
import
timm.models
as
tm
except
:
pass
import
timm.models
as
tm
from
colossalai.fx
import
ColoTracer
from
torch.fx
import
GraphModule
...
...
tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
View file @
75abc75c
import
torch
import
pytest
try
:
import
torchvision.models
as
tm
except
:
pass
import
torchvision
import
torchvision.models
as
tm
from
packaging
import
version
from
colossalai.fx
import
ColoTracer
from
torch.fx
import
GraphModule
...
...
@@ -11,16 +9,22 @@ from torch.fx import GraphModule
def
test_torchvision_models
():
MODEL_LIST
=
[
tm
.
vgg11
,
tm
.
resnet18
,
tm
.
densenet121
,
tm
.
mobilenet_v3_small
,
tm
.
resnext50_32x4d
,
tm
.
wide_resnet50_2
,
tm
.
regnet_x_16gf
,
tm
.
vit_b_16
,
tm
.
convnext_small
,
tm
.
mnasnet0_5
,
tm
.
efficientnet_b0
tm
.
regnet_x_16gf
,
tm
.
mnasnet0_5
,
tm
.
efficientnet_b0
]
RANDOMIZED_MODELS
=
[
tm
.
efficientnet_b0
]
if
version
.
parse
(
torchvision
.
__version__
)
>=
version
.
parse
(
'0.12.0'
):
MODEL_LIST
.
extend
([
tm
.
vit_b_16
,
tm
.
convnext_small
])
RANDOMIZED_MODELS
.
append
(
tm
.
convnext_small
)
torch
.
backends
.
cudnn
.
deterministic
=
True
tracer
=
ColoTracer
()
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
for
model_cls
in
MODEL_LIST
:
if
model_cls
in
[
tm
.
convnext_small
,
tm
.
efficientnet_b0
]
:
if
model_cls
in
RANDOMIZED_MODELS
:
# remove the impact of randomicity
model
=
model_cls
(
stochastic_depth_prob
=
0
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment