Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
3da68d6b
Unverified
Commit
3da68d6b
authored
Aug 25, 2022
by
Frank Lee
Committed by
GitHub
Aug 25, 2022
Browse files
[fx] fixed adapative pooling size concatenation error (#1489)
parent
cde7b8a5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
98 additions
and
17 deletions
+98
-17
colossalai/fx/tracer/meta_patch/patched_module/pooling.py
colossalai/fx/tracer/meta_patch/patched_module/pooling.py
+25
-17
tests/test_fx/test_tracer/test_patched_module.py
tests/test_fx/test_tracer/test_patched_module.py
+73
-0
No files found.
colossalai/fx/tracer/meta_patch/patched_module/pooling.py
View file @
3da68d6b
...
...
@@ -22,7 +22,7 @@ def torch_nn_avgpool1d(self, input):
l_out
=
math
.
floor
((
l_in
+
2
*
padding
[
0
]
-
kernel_size
[
0
])
/
stride
[
0
]
+
1
)
result_shape
=
input
.
shape
[:
-
1
]
+
(
l_out
,)
result_shape
=
tuple
(
input
.
shape
[:
-
1
]
)
+
(
l_out
,)
return
torch
.
empty
(
result_shape
,
device
=
'meta'
)
...
...
@@ -46,7 +46,7 @@ def torch_nn_avgpool2d(self, input):
h_out
=
math
.
floor
((
h_in
+
2
*
padding
[
0
]
-
kernel_size
[
0
])
/
stride
[
0
]
+
1
)
w_out
=
math
.
floor
((
w_in
+
2
*
padding
[
1
]
-
kernel_size
[
1
])
/
stride
[
1
]
+
1
)
result_shape
=
input
.
shape
[:
-
2
]
+
(
result_shape
=
tuple
(
input
.
shape
[:
-
2
]
)
+
(
h_out
,
w_out
,
)
...
...
@@ -74,7 +74,7 @@ def torch_nn_avgpool3d(self, input):
h_out
=
math
.
floor
((
h_in
+
2
*
padding
[
1
]
-
kernel_size
[
1
])
/
stride
[
1
]
+
1
)
w_out
=
math
.
floor
((
w_in
+
2
*
padding
[
2
]
-
kernel_size
[
2
])
/
stride
[
2
]
+
1
)
result_shape
=
input
.
shape
[:
-
3
]
+
(
result_shape
=
tuple
(
input
.
shape
[:
-
3
]
)
+
(
d_out
,
h_out
,
w_out
,
...
...
@@ -102,7 +102,7 @@ def torch_nn_maxpool1d(self, input):
l_out
=
math
.
floor
((
l_in
+
2
*
padding
[
0
]
-
dilation
[
0
]
*
(
kernel_size
[
0
]
-
1
)
-
1
)
/
stride
[
0
]
+
1
)
result_shape
=
input
.
shape
[:
-
1
]
+
(
l_out
,)
result_shape
=
tuple
(
input
.
shape
[:
-
1
]
)
+
(
l_out
,)
return
torch
.
empty
(
result_shape
,
device
=
'meta'
)
...
...
@@ -127,7 +127,7 @@ def torch_nn_maxpool2d(self, input):
h_out
=
math
.
floor
((
h_in
+
2
*
padding
[
0
]
-
dilation
[
0
]
*
(
kernel_size
[
0
]
-
1
)
-
1
)
/
stride
[
0
]
+
1
)
w_out
=
math
.
floor
((
w_in
+
2
*
padding
[
1
]
-
dilation
[
1
]
*
(
kernel_size
[
1
]
-
1
)
-
1
)
/
stride
[
1
]
+
1
)
result_shape
=
input
.
shape
[:
-
2
]
+
(
result_shape
=
tuple
(
input
.
shape
[:
-
2
]
)
+
(
h_out
,
w_out
,
)
...
...
@@ -156,7 +156,7 @@ def torch_nn_maxpool3d(self, input):
h_out
=
math
.
floor
((
h_in
+
2
*
padding
[
1
]
-
dilation
[
1
]
*
(
kernel_size
[
1
]
-
1
)
-
1
)
/
stride
[
1
]
+
1
)
w_out
=
math
.
floor
((
w_in
+
2
*
padding
[
2
]
-
dilation
[
2
]
*
(
kernel_size
[
2
]
-
1
)
-
1
)
/
stride
[
2
]
+
1
)
result_shape
=
input
.
shape
[:
-
3
]
+
(
result_shape
=
tuple
(
input
.
shape
[:
-
3
]
)
+
(
d_out
,
h_out
,
w_out
,
...
...
@@ -167,26 +167,34 @@ def torch_nn_maxpool3d(self, input):
@meta_patched_module.register(torch.nn.AdaptiveAvgPool1d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool1d)
def torch_nn_adapative_pooling_1d(self, input):
    """Meta-shape inference for 1d adaptive pooling modules.

    Accepts a 2d (C, L) or 3d (N, C, L) meta tensor and returns an empty
    meta tensor whose last dimension is the module's ``output_size``.
    ``output_size`` may be an int or a 1-tuple; it is normalized to a tuple
    so the shape concatenation below is always tuple + tuple.
    """
    assert input.dim() in [2, 3]
    # Normalize an int output_size to a 1-tuple before concatenating.
    size = self.output_size
    output_size = (size,) if isinstance(size, int) else size
    # torch.Size sliced is still torch.Size; coerce to a plain tuple so the
    # '+' below concatenates rather than failing on mixed types.
    result_shape = tuple(input.shape[:-1]) + output_size
    return torch.empty(result_shape, device='meta')
@meta_patched_module.register(torch.nn.AdaptiveAvgPool2d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool2d)
def torch_nn_adapative_pooling_2d(self, input):
    """Meta-shape inference for 2d adaptive pooling modules.

    Accepts a 3d (C, H, W) or 4d (N, C, H, W) meta tensor and returns an
    empty meta tensor whose last two dimensions follow the module's
    ``output_size``. ``output_size`` may be an int (broadcast to both dims)
    or a 2-tuple; per the PyTorch contract a tuple entry of ``None`` means
    "keep the corresponding input dimension".
    """
    assert input.dim() in [3, 4]
    if isinstance(self.output_size, int):
        output_size = (self.output_size,) * 2
    else:
        # A None entry keeps the matching input dimension (PyTorch semantics
        # for AdaptiveAvgPool2d / AdaptiveMaxPool2d output_size).
        output_size = tuple(
            input.shape[-2 + i] if dim is None else dim for i, dim in enumerate(self.output_size))
    # Coerce torch.Size to a plain tuple so '+' concatenates shapes.
    result_shape = tuple(input.shape[:-2]) + output_size
    return torch.empty(result_shape, device='meta')
@meta_patched_module.register(torch.nn.AdaptiveAvgPool3d)
@meta_patched_module.register(torch.nn.AdaptiveMaxPool3d)
def torch_nn_adapative_pooling_3d(self, input):
    """Meta-shape inference for 3d adaptive pooling modules.

    Accepts a 4d (C, D, H, W) or 5d (N, C, D, H, W) meta tensor and returns
    an empty meta tensor whose last three dimensions follow the module's
    ``output_size``. ``output_size`` may be an int (broadcast to all three
    dims) or a 3-tuple; per the PyTorch contract a tuple entry of ``None``
    means "keep the corresponding input dimension".
    """
    assert input.dim() in [4, 5]
    if isinstance(self.output_size, int):
        output_size = (self.output_size,) * 3
    else:
        # A None entry keeps the matching input dimension (PyTorch semantics
        # for AdaptiveAvgPool3d / AdaptiveMaxPool3d output_size).
        output_size = tuple(
            input.shape[-3 + i] if dim is None else dim for i, dim in enumerate(self.output_size))
    # Coerce torch.Size to a plain tuple so '+' concatenates shapes.
    result_shape = tuple(input.shape[:-3]) + output_size
    return torch.empty(result_shape, device='meta')
tests/test_fx/test_tracer/test_patched_module.py
View file @
3da68d6b
...
...
@@ -407,3 +407,76 @@ def test_pool3d():
# test max pool 3d
data
=
torch
.
rand
(
2
,
3
,
4
)
_assert_output_shape
(
data
=
data
,
module
=
pooler
,
patch_fn
=
patch_func
,
expect_exception
=
True
,
output_shape
=
None
)
# adaptive pooling is different from the other pooling layers, so test it individually
def test_adaptive_pooling_1d():
    """AdaptiveAvgPool1d meta patch: matches real shapes for 2d/3d inputs, rejects 4d."""
    pooler = torch.nn.AdaptiveAvgPool1d(output_size=3)
    patch_func = patched_module.torch_nn_adapative_pooling_1d

    # Supported ranks: the patched shape must equal the eager output shape.
    for shape in [(3, 4), (2, 3, 4)]:
        data = torch.rand(*shape)
        expected = pooler(data)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=False,
                             output_shape=expected.shape)

    # A 4d tensor is out of range for 1d pooling and must raise.
    data = torch.rand(2, 3, 4, 5)
    _assert_output_shape(data=data,
                         module=pooler,
                         patch_fn=patch_func,
                         expect_exception=True,
                         output_shape=None)
def test_adaptive_pooling_2d():
    """AdaptiveAvgPool2d meta patch: rejects 2d input, matches real shapes for 3d/4d."""
    pooler = torch.nn.AdaptiveAvgPool2d(output_size=3)
    patch_func = patched_module.torch_nn_adapative_pooling_2d

    # A 2d tensor is too small for 2d pooling and must raise.
    data = torch.rand(3, 4)
    _assert_output_shape(data=data,
                         module=pooler,
                         patch_fn=patch_func,
                         expect_exception=True,
                         output_shape=None)

    # Supported ranks: the patched shape must equal the eager output shape.
    for shape in [(2, 3, 4), (2, 3, 4, 5)]:
        data = torch.rand(*shape)
        expected = pooler(data)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=False,
                             output_shape=expected.shape)
def test_adaptive_pooling_3d():
    """AdaptiveAvgPool3d meta patch: rejects 3d input, matches real shapes for 4d/5d."""
    pooler = torch.nn.AdaptiveAvgPool3d(output_size=3)
    patch_func = patched_module.torch_nn_adapative_pooling_3d

    # A 3d tensor is too small for 3d pooling and must raise.
    data = torch.rand(3, 4, 5)
    _assert_output_shape(data=data,
                         module=pooler,
                         patch_fn=patch_func,
                         expect_exception=True,
                         output_shape=None)

    # Supported ranks: the patched shape must equal the eager output shape.
    for shape in [(2, 3, 4, 5), (2, 3, 4, 5, 6)]:
        data = torch.rand(*shape)
        expected = pooler(data)
        _assert_output_shape(data=data,
                             module=pooler,
                             patch_fn=patch_func,
                             expect_exception=False,
                             output_shape=expected.shape)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment