Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
86d9f468
Commit
86d9f468
authored
Nov 15, 2020
by
dreamerlin
Browse files
add MaxPool3d
parent
8ccea202
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
53 additions
and
7 deletions
+53
-7
mmcv/cnn/__init__.py
mmcv/cnn/__init__.py
+4
-3
mmcv/cnn/bricks/__init__.py
mmcv/cnn/bricks/__init__.py
+2
-2
mmcv/cnn/bricks/wrappers.py
mmcv/cnn/bricks/wrappers.py
+20
-1
tests/test_cnn/test_wrappers.py
tests/test_cnn/test_wrappers.py
+27
-1
No files found.
mmcv/cnn/__init__.py
View file @
86d9f468
...
@@ -6,8 +6,8 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
...
@@ -6,8 +6,8 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
ContextBlock
,
Conv2d
,
ConvAWS2d
,
ConvModule
,
ContextBlock
,
Conv2d
,
ConvAWS2d
,
ConvModule
,
ConvTranspose2d
,
ConvTranspose3d
,
ConvWS2d
,
ConvTranspose2d
,
ConvTranspose3d
,
ConvWS2d
,
DepthwiseSeparableConvModule
,
GeneralizedAttention
,
DepthwiseSeparableConvModule
,
GeneralizedAttention
,
HSigmoid
,
HSwish
,
Linear
,
MaxPool2d
,
NonLocal1
d
,
HSigmoid
,
HSwish
,
Linear
,
MaxPool2d
,
MaxPool3
d
,
NonLocal2d
,
NonLocal3d
,
Scale
,
Swish
,
NonLocal1d
,
NonLocal2d
,
NonLocal3d
,
Scale
,
Swish
,
build_activation_layer
,
build_conv_layer
,
build_activation_layer
,
build_conv_layer
,
build_norm_layer
,
build_padding_layer
,
build_plugin_layer
,
build_norm_layer
,
build_padding_layer
,
build_plugin_layer
,
build_upsample_layer
,
conv_ws_2d
,
is_norm
)
build_upsample_layer
,
conv_ws_2d
,
is_norm
)
...
@@ -29,5 +29,6 @@ __all__ = [
...
@@ -29,5 +29,6 @@ __all__ = [
'CONV_LAYERS'
,
'NORM_LAYERS'
,
'PADDING_LAYERS'
,
'UPSAMPLE_LAYERS'
,
'CONV_LAYERS'
,
'NORM_LAYERS'
,
'PADDING_LAYERS'
,
'UPSAMPLE_LAYERS'
,
'PLUGIN_LAYERS'
,
'Scale'
,
'get_model_complexity_info'
,
'conv_ws_2d'
,
'PLUGIN_LAYERS'
,
'Scale'
,
'get_model_complexity_info'
,
'conv_ws_2d'
,
'ConvAWS2d'
,
'ConvWS2d'
,
'fuse_conv_bn'
,
'DepthwiseSeparableConvModule'
,
'ConvAWS2d'
,
'ConvWS2d'
,
'fuse_conv_bn'
,
'DepthwiseSeparableConvModule'
,
'Linear'
,
'Conv2d'
,
'ConvTranspose2d'
,
'MaxPool2d'
,
'ConvTranspose3d'
'Linear'
,
'Conv2d'
,
'ConvTranspose2d'
,
'MaxPool2d'
,
'ConvTranspose3d'
,
'MaxPool3d'
]
]
mmcv/cnn/bricks/__init__.py
View file @
86d9f468
...
@@ -18,7 +18,7 @@ from .scale import Scale
...
@@ -18,7 +18,7 @@ from .scale import Scale
from
.swish
import
Swish
from
.swish
import
Swish
from
.upsample
import
build_upsample_layer
from
.upsample
import
build_upsample_layer
from
.wrappers
import
(
Conv2d
,
ConvTranspose2d
,
ConvTranspose3d
,
Linear
,
from
.wrappers
import
(
Conv2d
,
ConvTranspose2d
,
ConvTranspose3d
,
Linear
,
MaxPool2d
)
MaxPool2d
,
MaxPool3d
)
__all__
=
[
__all__
=
[
'ConvModule'
,
'build_activation_layer'
,
'build_conv_layer'
,
'ConvModule'
,
'build_activation_layer'
,
'build_conv_layer'
,
...
@@ -29,5 +29,5 @@ __all__ = [
...
@@ -29,5 +29,5 @@ __all__ = [
'UPSAMPLE_LAYERS'
,
'PLUGIN_LAYERS'
,
'Scale'
,
'ConvAWS2d'
,
'ConvWS2d'
,
'UPSAMPLE_LAYERS'
,
'PLUGIN_LAYERS'
,
'Scale'
,
'ConvAWS2d'
,
'ConvWS2d'
,
'conv_ws_2d'
,
'DepthwiseSeparableConvModule'
,
'Swish'
,
'Linear'
,
'conv_ws_2d'
,
'DepthwiseSeparableConvModule'
,
'Swish'
,
'Linear'
,
'Conv2dAdaptivePadding'
,
'Conv2d'
,
'ConvTranspose2d'
,
'MaxPool2d'
,
'Conv2dAdaptivePadding'
,
'Conv2d'
,
'ConvTranspose2d'
,
'MaxPool2d'
,
'ConvTranspose3d'
'ConvTranspose3d'
,
'MaxPool3d'
]
]
mmcv/cnn/bricks/wrappers.py
View file @
86d9f468
...
@@ -8,7 +8,7 @@ import math
...
@@ -8,7 +8,7 @@ import math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.nn.modules.utils
import
_pair
from
torch.nn.modules.utils
import
_pair
,
_triple
from
.registry
import
CONV_LAYERS
,
UPSAMPLE_LAYERS
from
.registry
import
CONV_LAYERS
,
UPSAMPLE_LAYERS
...
@@ -122,6 +122,25 @@ class MaxPool2d(nn.MaxPool2d):
...
@@ -122,6 +122,25 @@ class MaxPool2d(nn.MaxPool2d):
return
super
().
forward
(
x
)
return
super
().
forward
(
x
)
class
MaxPool3d
(
nn
.
MaxPool3d
):
def
forward
(
self
,
x
):
# PyTorch 1.7 does not support empty tensor inference yet
if
x
.
numel
()
==
0
and
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
7
)):
out_shape
=
list
(
x
.
shape
[:
2
])
for
i
,
k
,
p
,
s
,
d
in
zip
(
x
.
shape
[
-
3
:],
_triple
(
self
.
kernel_size
),
_triple
(
self
.
padding
),
_triple
(
self
.
stride
),
_triple
(
self
.
dilation
)):
o
=
(
i
+
2
*
p
-
(
d
*
(
k
-
1
)
+
1
))
/
s
+
1
o
=
math
.
ceil
(
o
)
if
self
.
ceil_mode
else
math
.
floor
(
o
)
out_shape
.
append
(
o
)
empty
=
NewEmptyTensorOp
.
apply
(
x
,
out_shape
)
return
empty
return
super
().
forward
(
x
)
class
Linear
(
torch
.
nn
.
Linear
):
class
Linear
(
torch
.
nn
.
Linear
):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
...
tests/test_cnn/test_wrappers.py
View file @
86d9f468
...
@@ -6,7 +6,7 @@ import torch
...
@@ -6,7 +6,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmcv.cnn.bricks
import
(
Conv2d
,
ConvTranspose2d
,
ConvTranspose3d
,
Linear
,
from
mmcv.cnn.bricks
import
(
Conv2d
,
ConvTranspose2d
,
ConvTranspose3d
,
Linear
,
MaxPool2d
)
MaxPool2d
,
MaxPool3d
)
@
patch
(
'torch.__version__'
,
'1.1'
)
@
patch
(
'torch.__version__'
,
'1.1'
)
...
@@ -186,6 +186,32 @@ def test_max_pool_2d():
...
@@ -186,6 +186,32 @@ def test_max_pool_2d():
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
@
patch
(
'torch.__version__'
,
'1.1'
)
def
test_max_pool_3d
():
test_cases
=
OrderedDict
([(
'in_w'
,
[
10
,
20
]),
(
'in_h'
,
[
10
,
20
]),
(
'in_t'
,
[
10
,
20
]),
(
'in_channel'
,
[
1
,
3
]),
(
'out_channel'
,
[
1
,
3
]),
(
'kernel_size'
,
[
3
,
5
]),
(
'stride'
,
[
1
,
2
]),
(
'padding'
,
[
0
,
1
]),
(
'dilation'
,
[
1
,
2
])])
for
in_h
,
in_w
,
in_t
,
in_cha
,
out_cha
,
k
,
s
,
p
,
d
in
product
(
*
list
(
test_cases
.
values
())):
# wrapper op with 0-dim input
x_empty
=
torch
.
randn
(
0
,
in_cha
,
in_t
,
in_h
,
in_w
,
requires_grad
=
True
)
wrapper
=
MaxPool3d
(
k
,
stride
=
s
,
padding
=
p
,
dilation
=
d
)
wrapper_out
=
wrapper
(
x_empty
)
# torch op with 3-dim input as shape reference
x_normal
=
torch
.
randn
(
3
,
in_cha
,
in_t
,
in_h
,
in_w
)
ref
=
nn
.
MaxPool3d
(
k
,
stride
=
s
,
padding
=
p
,
dilation
=
d
)
ref_out
=
ref
(
x_normal
)
assert
wrapper_out
.
shape
[
0
]
==
0
assert
wrapper_out
.
shape
[
1
:]
==
ref_out
.
shape
[
1
:]
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
@
patch
(
'torch.__version__'
,
'1.1'
)
@
patch
(
'torch.__version__'
,
'1.1'
)
def
test_linear
():
def
test_linear
():
test_cases
=
OrderedDict
([
test_cases
=
OrderedDict
([
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment