Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
abf6a262
"runtime/vscode:/vscode.git/clone" did not exist on "5701753ab8bce0e2d430e35b9502948695d3e4e2"
Unverified
Commit
abf6a262
authored
Jul 04, 2022
by
Frank Lee
Committed by
GitHub
Jul 04, 2022
Browse files
[fx] added module patch for pooling layers (#1197)
parent
23442a5b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
91 additions
and
1 deletion
+91
-1
colossalai/fx/tracer/meta_patch/__init__.py
colossalai/fx/tracer/meta_patch/__init__.py
+0
-1
colossalai/fx/tracer/meta_patch/patched_module.py
colossalai/fx/tracer/meta_patch/patched_module.py
+30
-0
tests/test_fx/test_tracer/test_non_patched_module.py
tests/test_fx/test_tracer/test_non_patched_module.py
+31
-0
tests/test_fx/test_tracer/test_patched_module.py
tests/test_fx/test_tracer/test_patched_module.py
+30
-0
No files found.
colossalai/fx/tracer/meta_patch/__init__.py
View file @
abf6a262
from
sys
import
meta_path
from
.registry
import
*
from
.patched_function
import
*
from
.patched_module
import
*
colossalai/fx/tracer/meta_patch/patched_module.py
View file @
abf6a262
...
...
@@ -86,3 +86,33 @@ def torch_nn_conv3d(self, input):
w_out
,
)
return
torch
.
empty
(
result_shape
,
device
=
'meta'
)
@
meta_patched_module
.
register
(
torch
.
nn
.
MaxPool3d
)
def
torch_nn_maxpool3d
(
self
,
input
):
num_dim
=
input
.
dim
()
assert
num_dim
in
[
4
,
5
],
f
'expected the input to have 4 or 5 dimensions, but got
{
num_dim
}
dimensions'
d_in
,
h_in
,
w_in
=
input
.
shape
[
-
3
:]
def
_convert_int_to_list
(
item
):
if
isinstance
(
item
,
int
):
return
[
item
]
*
3
else
:
return
item
padding
=
_convert_int_to_list
(
self
.
padding
)
dilation
=
_convert_int_to_list
(
self
.
dilation
)
kernel_size
=
_convert_int_to_list
(
self
.
kernel_size
)
stride
=
_convert_int_to_list
(
self
.
stride
)
d_out
=
math
.
floor
((
d_in
+
2
*
padding
[
0
]
-
dilation
[
0
]
*
(
kernel_size
[
0
]
-
1
)
-
1
)
/
stride
[
0
]
+
1
)
h_out
=
math
.
floor
((
h_in
+
2
*
padding
[
1
]
-
dilation
[
1
]
*
(
kernel_size
[
1
]
-
1
)
-
1
)
/
stride
[
1
]
+
1
)
w_out
=
math
.
floor
((
w_in
+
2
*
padding
[
2
]
-
dilation
[
2
]
*
(
kernel_size
[
2
]
-
1
)
-
1
)
/
stride
[
2
]
+
1
)
result_shape
=
input
.
shape
[:
-
3
]
+
(
d_out
,
h_out
,
w_out
,
)
return
torch
.
empty
(
result_shape
,
device
=
'meta'
)
tests/test_fx/test_tracer/test_non_patched_module.py
0 → 100644
View file @
abf6a262
import
torch
import
torch.nn
def
test_maxpool
():
layer_to_test
=
dict
(
maxpool_1d
=
dict
(
layer
=
torch
.
nn
.
MaxPool1d
,
shape
=
(
4
,
3
,
4
)),
maxpool_2d
=
dict
(
layer
=
torch
.
nn
.
MaxPool2d
,
shape
=
(
4
,
3
,
4
,
4
)))
for
name
,
info
in
layer_to_test
.
items
():
data
=
torch
.
rand
(
*
info
[
'shape'
])
meta_data
=
data
.
to
(
'meta'
)
layer
=
info
[
'layer'
](
kernel_size
=
3
)
out
=
layer
(
data
)
meta_out
=
layer
(
meta_data
)
assert
meta_out
.
is_meta
assert
out
.
shape
==
meta_out
.
shape
def
test_avgpool
():
layer_to_test
=
dict
(
maxpool_1d
=
dict
(
layer
=
torch
.
nn
.
AvgPool1d
,
shape
=
(
4
,
3
,
4
)),
maxpool_2d
=
dict
(
layer
=
torch
.
nn
.
AvgPool2d
,
shape
=
(
4
,
3
,
4
,
4
)),
maxpool_3d
=
dict
(
layer
=
torch
.
nn
.
AvgPool3d
,
shape
=
(
4
,
3
,
4
,
4
,
4
)))
for
name
,
info
in
layer_to_test
.
items
():
data
=
torch
.
rand
(
*
info
[
'shape'
])
meta_data
=
data
.
to
(
'meta'
)
layer
=
info
[
'layer'
](
kernel_size
=
3
)
out
=
layer
(
data
)
meta_out
=
layer
(
meta_data
)
assert
meta_out
.
is_meta
assert
out
.
shape
==
meta_out
.
shape
tests/test_fx/test_tracer/test_patched_module.py
View file @
abf6a262
...
...
@@ -225,3 +225,33 @@ def test_conv3d():
patch_fn
=
patched_module
.
torch_nn_conv3d
,
expect_exception
=
False
,
output_shape
=
materialized_output
.
shape
)
def
test_maxpool3d
():
pooler
=
torch
.
nn
.
MaxPool3d
(
kernel_size
=
3
)
# test max pool 3d
data
=
torch
.
rand
(
2
,
3
,
4
,
4
,
4
)
materialized_output
=
pooler
(
data
)
_assert_output_shape
(
data
=
data
,
module
=
pooler
,
patch_fn
=
patched_module
.
torch_nn_maxpool3d
,
expect_exception
=
False
,
output_shape
=
materialized_output
.
shape
)
# test max pool 3d
data
=
torch
.
rand
(
2
,
3
,
4
,
4
)
materialized_output
=
pooler
(
data
)
_assert_output_shape
(
data
=
data
,
module
=
pooler
,
patch_fn
=
patched_module
.
torch_nn_maxpool3d
,
expect_exception
=
False
,
output_shape
=
materialized_output
.
shape
)
# test max pool 3d
data
=
torch
.
rand
(
2
,
3
,
4
)
_assert_output_shape
(
data
=
data
,
module
=
pooler
,
patch_fn
=
patched_module
.
torch_nn_maxpool3d
,
expect_exception
=
True
,
output_shape
=
None
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment