Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
75abc75c
Unverified
Commit
75abc75c
authored
Jul 18, 2022
by
Frank Lee
Committed by
GitHub
Jul 18, 2022
Browse files
[fx] fixed compatiblity issue with torch 1.10 (#1331)
parent
069d6fdc
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
32 additions
and
28 deletions
+32
-28
colossalai/fx/passes/split_module.py
colossalai/fx/passes/split_module.py
+8
-4
colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
...salai/fx/tracer/meta_patch/patched_function/arithmetic.py
+1
-0
colossalai/fx/tracer/tracer.py
colossalai/fx/tracer/tracer.py
+3
-0
tests/test_fx/test_pipeline/test_timm_model/test_timm.py
tests/test_fx/test_pipeline/test_timm_model/test_timm.py
+1
-5
tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py
...est_fx/test_pipeline/test_torchvision/test_torchvision.py
+7
-7
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
+1
-5
tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
...t_tracer/test_torchvision_model/test_torchvision_model.py
+11
-7
No files found.
colossalai/fx/passes/split_module.py
View file @
75abc75c
...
@@ -2,6 +2,7 @@ import torch
...
@@ -2,6 +2,7 @@ import torch
from
torch.fx.graph_module
import
GraphModule
from
torch.fx.graph_module
import
GraphModule
from
typing
import
Callable
,
List
,
Dict
,
Any
,
Optional
from
typing
import
Callable
,
List
,
Dict
,
Any
,
Optional
from
torch.fx._compatibility
import
compatibility
from
torch.fx._compatibility
import
compatibility
from
packaging
import
version
import
inspect
import
inspect
...
@@ -233,6 +234,9 @@ def split_module(
...
@@ -233,6 +234,9 @@ def split_module(
base_mod_attrs
:
Dict
[
str
,
torch
.
fx
.
graph_module
.
GraphModule
]
=
{}
base_mod_attrs
:
Dict
[
str
,
torch
.
fx
.
graph_module
.
GraphModule
]
=
{}
for
node
in
m
.
graph
.
nodes
:
for
node
in
m
.
graph
.
nodes
:
if
node
.
op
==
'placeholder'
:
if
node
.
op
==
'placeholder'
:
if
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'1.11.0'
):
base_mod_env
[
node
.
name
]
=
base_mod_graph
.
placeholder
(
node
.
name
,
type_expr
=
node
.
type
)
else
:
default_value
=
node
.
args
[
0
]
if
len
(
node
.
args
)
>
0
else
inspect
.
Signature
.
empty
default_value
=
node
.
args
[
0
]
if
len
(
node
.
args
)
>
0
else
inspect
.
Signature
.
empty
base_mod_env
[
node
.
name
]
=
base_mod_graph
.
placeholder
(
node
.
name
,
base_mod_env
[
node
.
name
]
=
base_mod_graph
.
placeholder
(
node
.
name
,
type_expr
=
node
.
type
,
type_expr
=
node
.
type
,
...
...
colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
View file @
75abc75c
...
@@ -3,6 +3,7 @@ from ..registry import meta_patched_function
...
@@ -3,6 +3,7 @@ from ..registry import meta_patched_function
@
meta_patched_function
.
register
(
torch
.
matmul
)
@
meta_patched_function
.
register
(
torch
.
matmul
)
@
meta_patched_function
.
register
(
'matmul'
)
# for built-in op @
def
torch_matmul
(
input
,
other
,
*
,
out
=
None
):
def
torch_matmul
(
input
,
other
,
*
,
out
=
None
):
# copied from huggingface.utils.fx
# copied from huggingface.utils.fx
d1
=
input
.
dim
()
d1
=
input
.
dim
()
...
...
colossalai/fx/tracer/tracer.py
View file @
75abc75c
...
@@ -96,6 +96,9 @@ class ColoTracer(Tracer):
...
@@ -96,6 +96,9 @@ class ColoTracer(Tracer):
# fetch patched function
# fetch patched function
if
meta_patched_function
.
has
(
target
):
if
meta_patched_function
.
has
(
target
):
meta_target
=
meta_patched_function
.
get
(
target
)
meta_target
=
meta_patched_function
.
get
(
target
)
elif
meta_patched_function
.
has
(
target
.
__name__
):
# use name for some builtin op like @ (matmul)
meta_target
=
meta_patched_function
.
get
(
target
.
__name__
)
else
:
else
:
meta_target
=
target
meta_target
=
target
...
...
tests/test_fx/test_pipeline/test_timm_model/test_timm.py
View file @
75abc75c
import
torch
import
torch
import
pytest
import
timm.models
as
tm
try
:
import
timm.models
as
tm
except
:
pass
from
timm_utils
import
split_model_and_compare_output
from
timm_utils
import
split_model_and_compare_output
...
...
tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py
View file @
75abc75c
import
torch
import
torch
try
:
import
torchvision
import
torchvision.models
as
tm
import
torchvision.models
as
tm
except
:
pass
from
colossalai.fx
import
ColoTracer
from
colossalai.fx
import
ColoTracer
from
colossalai.fx.passes.adding_split_node_pass
import
split_with_split_nodes_pass
,
balanced_split_pass
from
colossalai.fx.passes.adding_split_node_pass
import
split_with_split_nodes_pass
,
balanced_split_pass
from
torch.fx
import
GraphModule
from
torch.fx
import
GraphModule
from
packaging
import
version
import
random
import
random
import
numpy
as
np
import
numpy
as
np
import
inspect
import
inspect
import
pytest
MANUAL_SEED
=
0
MANUAL_SEED
=
0
random
.
seed
(
MANUAL_SEED
)
random
.
seed
(
MANUAL_SEED
)
...
@@ -22,9 +19,12 @@ torch.backends.cudnn.deterministic = True
...
@@ -22,9 +19,12 @@ torch.backends.cudnn.deterministic = True
def
test_torchvision_models
():
def
test_torchvision_models
():
MODEL_LIST
=
[
MODEL_LIST
=
[
tm
.
vgg11
,
tm
.
resnet18
,
tm
.
densenet121
,
tm
.
mobilenet_v3_small
,
tm
.
resnext50_32x4d
,
tm
.
wide_resnet50_2
,
tm
.
vgg11
,
tm
.
resnet18
,
tm
.
densenet121
,
tm
.
mobilenet_v3_small
,
tm
.
resnext50_32x4d
,
tm
.
wide_resnet50_2
,
tm
.
regnet_x_16gf
,
tm
.
vit_b_16
,
tm
.
convnext_small
,
tm
.
efficientnet_b0
,
tm
.
mnasnet0_5
tm
.
regnet_x_16gf
,
tm
.
efficientnet_b0
,
tm
.
mnasnet0_5
]
]
if
version
.
parse
(
torchvision
.
__version__
)
>=
version
.
parse
(
'0.12.0'
):
MODEL_LIST
.
extend
([
tm
.
vit_b_16
,
tm
.
convnext_small
])
tracer
=
ColoTracer
()
tracer
=
ColoTracer
()
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
...
...
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
View file @
75abc75c
import
torch
import
torch
import
pytest
import
timm.models
as
tm
try
:
import
timm.models
as
tm
except
:
pass
from
colossalai.fx
import
ColoTracer
from
colossalai.fx
import
ColoTracer
from
torch.fx
import
GraphModule
from
torch.fx
import
GraphModule
...
...
tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
View file @
75abc75c
import
torch
import
torch
import
pytest
import
torchvision
try
:
import
torchvision.models
as
tm
import
torchvision.models
as
tm
from
packaging
import
version
except
:
pass
from
colossalai.fx
import
ColoTracer
from
colossalai.fx
import
ColoTracer
from
torch.fx
import
GraphModule
from
torch.fx
import
GraphModule
...
@@ -11,16 +9,22 @@ from torch.fx import GraphModule
...
@@ -11,16 +9,22 @@ from torch.fx import GraphModule
def
test_torchvision_models
():
def
test_torchvision_models
():
MODEL_LIST
=
[
MODEL_LIST
=
[
tm
.
vgg11
,
tm
.
resnet18
,
tm
.
densenet121
,
tm
.
mobilenet_v3_small
,
tm
.
resnext50_32x4d
,
tm
.
wide_resnet50_2
,
tm
.
vgg11
,
tm
.
resnet18
,
tm
.
densenet121
,
tm
.
mobilenet_v3_small
,
tm
.
resnext50_32x4d
,
tm
.
wide_resnet50_2
,
tm
.
regnet_x_16gf
,
tm
.
vit_b_16
,
tm
.
convnext_small
,
tm
.
mnasnet0_5
,
tm
.
efficientnet_b0
tm
.
regnet_x_16gf
,
tm
.
mnasnet0_5
,
tm
.
efficientnet_b0
]
]
RANDOMIZED_MODELS
=
[
tm
.
efficientnet_b0
]
if
version
.
parse
(
torchvision
.
__version__
)
>=
version
.
parse
(
'0.12.0'
):
MODEL_LIST
.
extend
([
tm
.
vit_b_16
,
tm
.
convnext_small
])
RANDOMIZED_MODELS
.
append
(
tm
.
convnext_small
)
torch
.
backends
.
cudnn
.
deterministic
=
True
torch
.
backends
.
cudnn
.
deterministic
=
True
tracer
=
ColoTracer
()
tracer
=
ColoTracer
()
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
data
=
torch
.
rand
(
2
,
3
,
224
,
224
)
for
model_cls
in
MODEL_LIST
:
for
model_cls
in
MODEL_LIST
:
if
model_cls
in
[
tm
.
convnext_small
,
tm
.
efficientnet_b0
]
:
if
model_cls
in
RANDOMIZED_MODELS
:
# remove the impact of randomicity
# remove the impact of randomicity
model
=
model_cls
(
stochastic_depth_prob
=
0
)
model
=
model_cls
(
stochastic_depth_prob
=
0
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment