Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
9036241e
Unverified
Commit
9036241e
authored
Aug 03, 2023
by
youkaichao
Committed by
GitHub
Aug 03, 2023
Browse files
[Enhancement] Change the order of condition to make fx wok (#2883)
parent
f64d4858
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
7 deletions
+27
-7
mmcv/cnn/bricks/wrappers.py
mmcv/cnn/bricks/wrappers.py
+7
-7
tests/test_cnn/test_wrappers.py
tests/test_cnn/test_wrappers.py
+20
-0
No files found.
mmcv/cnn/bricks/wrappers.py
View file @
9036241e
...
...
@@ -41,7 +41,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
class
Conv2d
(
nn
.
Conv2d
):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
x
.
numel
()
==
0
and
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
4
)):
if
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
4
))
and
x
.
numel
()
==
0
:
out_shape
=
[
x
.
shape
[
0
],
self
.
out_channels
]
for
i
,
k
,
p
,
s
,
d
in
zip
(
x
.
shape
[
-
2
:],
self
.
kernel_size
,
self
.
padding
,
self
.
stride
,
self
.
dilation
):
...
...
@@ -62,7 +62,7 @@ class Conv2d(nn.Conv2d):
class
Conv3d
(
nn
.
Conv3d
):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
x
.
numel
()
==
0
and
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
4
)):
if
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
4
))
and
x
.
numel
()
==
0
:
out_shape
=
[
x
.
shape
[
0
],
self
.
out_channels
]
for
i
,
k
,
p
,
s
,
d
in
zip
(
x
.
shape
[
-
3
:],
self
.
kernel_size
,
self
.
padding
,
self
.
stride
,
self
.
dilation
):
...
...
@@ -84,7 +84,7 @@ class Conv3d(nn.Conv3d):
class
ConvTranspose2d
(
nn
.
ConvTranspose2d
):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
x
.
numel
()
==
0
and
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
4
)):
if
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
4
))
and
x
.
numel
()
==
0
:
out_shape
=
[
x
.
shape
[
0
],
self
.
out_channels
]
for
i
,
k
,
p
,
s
,
d
,
op
in
zip
(
x
.
shape
[
-
2
:],
self
.
kernel_size
,
self
.
padding
,
self
.
stride
,
...
...
@@ -106,7 +106,7 @@ class ConvTranspose2d(nn.ConvTranspose2d):
class
ConvTranspose3d
(
nn
.
ConvTranspose3d
):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
x
.
numel
()
==
0
and
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
4
)):
if
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
4
))
and
x
.
numel
()
==
0
:
out_shape
=
[
x
.
shape
[
0
],
self
.
out_channels
]
for
i
,
k
,
p
,
s
,
d
,
op
in
zip
(
x
.
shape
[
-
3
:],
self
.
kernel_size
,
self
.
padding
,
self
.
stride
,
...
...
@@ -127,7 +127,7 @@ class MaxPool2d(nn.MaxPool2d):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# PyTorch 1.9 does not support empty tensor inference yet
if
x
.
numel
()
==
0
and
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
9
)):
if
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
9
))
and
x
.
numel
()
==
0
:
out_shape
=
list
(
x
.
shape
[:
2
])
for
i
,
k
,
p
,
s
,
d
in
zip
(
x
.
shape
[
-
2
:],
_pair
(
self
.
kernel_size
),
_pair
(
self
.
padding
),
_pair
(
self
.
stride
),
...
...
@@ -145,7 +145,7 @@ class MaxPool3d(nn.MaxPool3d):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# PyTorch 1.9 does not support empty tensor inference yet
if
x
.
numel
()
==
0
and
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
9
)):
if
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
9
))
and
x
.
numel
()
==
0
:
out_shape
=
list
(
x
.
shape
[:
2
])
for
i
,
k
,
p
,
s
,
d
in
zip
(
x
.
shape
[
-
3
:],
_triple
(
self
.
kernel_size
),
_triple
(
self
.
padding
),
...
...
@@ -164,7 +164,7 @@ class Linear(torch.nn.Linear):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# empty tensor forward of Linear layer is supported in Pytorch 1.6
if
x
.
numel
()
==
0
and
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
5
)):
if
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
5
))
and
x
.
numel
()
==
0
:
out_shape
=
[
x
.
shape
[
0
],
self
.
out_features
]
empty
=
NewEmptyTensorOp
.
apply
(
x
,
out_shape
)
if
self
.
training
:
...
...
tests/test_cnn/test_wrappers.py
View file @
9036241e
...
...
@@ -4,6 +4,8 @@ from unittest.mock import patch
import
pytest
import
torch
import
torch.nn
as
nn
from
mmengine.utils
import
digit_version
from
mmengine.utils.dl_utils
import
TORCH_VERSION
from
mmcv.cnn.bricks
import
(
Conv2d
,
Conv3d
,
ConvTranspose2d
,
ConvTranspose3d
,
Linear
,
MaxPool2d
,
MaxPool3d
)
...
...
@@ -374,3 +376,21 @@ def test_nn_op_forward_called():
wrapper
=
Linear
(
3
,
3
)
wrapper
(
x_normal
)
nn_module_forward
.
assert_called_with
(
x_normal
)
@
pytest
.
mark
.
skipif
(
digit_version
(
TORCH_VERSION
)
<
digit_version
(
'1.10'
),
reason
=
'MaxPool2d and MaxPool3d will fail fx for torch<=1.9'
)
def
test_fx_compatibility
():
from
torch
import
fx
# ensure the fx trace can pass the network
for
Net
in
(
MaxPool2d
,
MaxPool3d
):
net
=
Net
(
1
)
gm_module
=
fx
.
symbolic_trace
(
net
)
# noqa: F841
for
Net
in
(
Linear
,
):
net
=
Net
(
1
,
1
)
gm_module
=
fx
.
symbolic_trace
(
net
)
# noqa: F841
for
Net
in
(
Conv2d
,
ConvTranspose2d
,
Conv3d
,
ConvTranspose3d
):
net
=
Net
(
1
,
1
,
1
)
gm_module
=
fx
.
symbolic_trace
(
net
)
# noqa: F841
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment