Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
84f2298a
Unverified
Commit
84f2298a
authored
Jul 07, 2022
by
Frank Lee
Committed by
GitHub
Jul 07, 2022
Browse files
[fx] added patches for tracing swin transformer (#1228)
parent
37fcf96b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
8 deletions
+35
-8
colossalai/fx/tracer/meta_patch/patched_function.py
colossalai/fx/tracer/meta_patch/patched_function.py
+5
-0
colossalai/fx/tracer/meta_patch/patched_module.py
colossalai/fx/tracer/meta_patch/patched_module.py
+28
-0
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
+2
-8
No files found.
colossalai/fx/tracer/meta_patch/patched_function.py
View file @
84f2298a
...
@@ -218,3 +218,8 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None):
...
@@ -218,3 +218,8 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None):
concatenated_dim
=
sum
(
shape
[
dim
]
for
shape
in
shapes
)
concatenated_dim
=
sum
(
shape
[
dim
]
for
shape
in
shapes
)
final_shape
=
shape
[:
dim
]
+
[
concatenated_dim
]
+
shape
[
dim
+
1
:]
final_shape
=
shape
[:
dim
]
+
[
concatenated_dim
]
+
shape
[
dim
+
1
:]
return
torch
.
empty
(
final_shape
,
device
=
"meta"
)
return
torch
.
empty
(
final_shape
,
device
=
"meta"
)
@
meta_patched_function
.
register
(
torch
.
roll
)
def
torch_roll
(
input
,
shifts
,
dims
=
None
):
return
torch
.
empty
(
input
.
shape
,
device
=
'meta'
)
colossalai/fx/tracer/meta_patch/patched_module.py
View file @
84f2298a
...
@@ -249,6 +249,34 @@ def torch_nn_maxpool3d(self, input):
...
@@ -249,6 +249,34 @@ def torch_nn_maxpool3d(self, input):
return
torch
.
empty
(
result_shape
,
device
=
'meta'
)
return
torch
.
empty
(
result_shape
,
device
=
'meta'
)
@
meta_patched_module
.
register
(
torch
.
nn
.
AdaptiveAvgPool1d
)
@
meta_patched_module
.
register
(
torch
.
nn
.
AdaptiveMaxPool1d
)
def
torch_nn_adapative_pooling_1d
(
self
,
input
):
result_shape
=
input
.
shape
[:
-
1
]
+
(
self
.
output_size
,)
return
torch
.
empty
(
result_shape
,
device
=
'meta'
)
@
meta_patched_module
.
register
(
torch
.
nn
.
AdaptiveAvgPool2d
)
@
meta_patched_module
.
register
(
torch
.
nn
.
AdaptiveMaxPool2d
)
def
torch_nn_adapative_pooling_2d
(
self
,
input
):
result_shape
=
input
.
shape
[:
-
2
]
+
(
self
.
output_size
,
self
.
output_size
,
)
return
torch
.
empty
(
result_shape
,
device
=
'meta'
)
@
meta_patched_module
.
register
(
torch
.
nn
.
AdaptiveAvgPool3d
)
@
meta_patched_module
.
register
(
torch
.
nn
.
AdaptiveMaxPool3d
)
def
torch_nn_adapative_pooling_3d
(
self
,
input
):
result_shape
=
input
.
shape
[:
-
3
]
+
(
self
.
output_size
,
self
.
output_size
,
self
.
output_size
,
)
return
torch
.
empty
(
result_shape
,
device
=
'meta'
)
@
meta_patched_module
.
register
(
torch
.
nn
.
ReLU
)
@
meta_patched_module
.
register
(
torch
.
nn
.
ReLU
)
@
meta_patched_module
.
register
(
torch
.
nn
.
ReLU6
)
@
meta_patched_module
.
register
(
torch
.
nn
.
ReLU6
)
def
torch_nn_func_relu
(
self
,
input
):
def
torch_nn_func_relu
(
self
,
input
):
...
...
tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
View file @
84f2298a
...
@@ -63,14 +63,8 @@ def test_timm_models_with_control_flow():
...
@@ -63,14 +63,8 @@ def test_timm_models_with_control_flow():
torch
.
backends
.
cudnn
.
deterministic
=
True
torch
.
backends
.
cudnn
.
deterministic
=
True
MODEL_LIST_WITH_CONTROL_FLOW
=
[
MODEL_LIST_WITH_CONTROL_FLOW
=
[
tm
.
convnext
.
convnext_base
,
tm
.
convnext
.
convnext_base
,
tm
.
vgg
.
vgg11
,
tm
.
dpn
.
dpn68
,
tm
.
densenet
.
densenet121
,
tm
.
rexnet
.
rexnet_100
,
tm
.
vgg
.
vgg11
,
tm
.
swin_transformer
.
swin_base_patch4_window7_224
tm
.
dpn
.
dpn68
,
tm
.
densenet
.
densenet121
,
tm
.
rexnet
.
rexnet_100
,
# not traceable
# tm.swin_transformer.swin_base_patch4_window7_224
]
]
tracer
=
ColoTracer
()
tracer
=
ColoTracer
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment