Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
dfa36dfe
Unverified
Commit
dfa36dfe
authored
Nov 20, 2020
by
Wenwei Zhang
Committed by
GitHub
Nov 20, 2020
Browse files
Merge pull request #652 from dreamerlin/3d
[Feature] Add 3D support in wrapper
parents
ec43b671
1a12ac75
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
349 additions
and
132 deletions
+349
-132
mmcv/cnn/__init__.py
mmcv/cnn/__init__.py
+8
-4
mmcv/cnn/bricks/__init__.py
mmcv/cnn/bricks/__init__.py
+4
-2
mmcv/cnn/bricks/wrappers.py
mmcv/cnn/bricks/wrappers.py
+65
-2
tests/test_cnn/test_wrappers.py
tests/test_cnn/test_wrappers.py
+272
-124
No files found.
mmcv/cnn/__init__.py
View file @
dfa36dfe
# Copyright (c) Open-MMLab. All rights reserved.
from
.alexnet
import
AlexNet
# yapf: disable
from
.bricks
import
(
ACTIVATION_LAYERS
,
CONV_LAYERS
,
NORM_LAYERS
,
PADDING_LAYERS
,
PLUGIN_LAYERS
,
UPSAMPLE_LAYERS
,
ContextBlock
,
Conv2d
,
ConvAWS2d
,
ConvModule
,
ConvTranspose2d
,
ConvWS2d
,
DepthwiseSeparableConvModule
,
GeneralizedAttention
,
HSigmoid
,
HSwish
,
Linear
,
MaxPool2d
,
ContextBlock
,
Conv2d
,
Conv3d
,
ConvAWS2d
,
ConvModule
,
ConvTranspose2d
,
ConvTranspose3d
,
ConvWS2d
,
DepthwiseSeparableConvModule
,
GeneralizedAttention
,
HSigmoid
,
HSwish
,
Linear
,
MaxPool2d
,
MaxPool3d
,
NonLocal1d
,
NonLocal2d
,
NonLocal3d
,
Scale
,
Swish
,
build_activation_layer
,
build_conv_layer
,
build_norm_layer
,
build_padding_layer
,
build_plugin_layer
,
build_upsample_layer
,
conv_ws_2d
,
is_norm
)
# yapf: enable
from
.resnet
import
ResNet
,
make_res_layer
from
.utils
import
(
bias_init_with_prob
,
caffe2_xavier_init
,
constant_init
,
fuse_conv_bn
,
get_model_complexity_info
,
kaiming_init
,
...
...
@@ -26,5 +29,6 @@ __all__ = [
'CONV_LAYERS'
,
'NORM_LAYERS'
,
'PADDING_LAYERS'
,
'UPSAMPLE_LAYERS'
,
'PLUGIN_LAYERS'
,
'Scale'
,
'get_model_complexity_info'
,
'conv_ws_2d'
,
'ConvAWS2d'
,
'ConvWS2d'
,
'fuse_conv_bn'
,
'DepthwiseSeparableConvModule'
,
'Linear'
,
'Conv2d'
,
'ConvTranspose2d'
,
'MaxPool2d'
'Linear'
,
'Conv2d'
,
'ConvTranspose2d'
,
'MaxPool2d'
,
'ConvTranspose3d'
,
'MaxPool3d'
,
'Conv3d'
]
mmcv/cnn/bricks/__init__.py
View file @
dfa36dfe
...
...
@@ -17,7 +17,8 @@ from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
from
.scale
import
Scale
from
.swish
import
Swish
from
.upsample
import
build_upsample_layer
from
.wrappers
import
Conv2d
,
ConvTranspose2d
,
Linear
,
MaxPool2d
from
.wrappers
import
(
Conv2d
,
Conv3d
,
ConvTranspose2d
,
ConvTranspose3d
,
Linear
,
MaxPool2d
,
MaxPool3d
)
__all__
=
[
'ConvModule'
,
'build_activation_layer'
,
'build_conv_layer'
,
...
...
@@ -27,5 +28,6 @@ __all__ = [
'ACTIVATION_LAYERS'
,
'CONV_LAYERS'
,
'NORM_LAYERS'
,
'PADDING_LAYERS'
,
'UPSAMPLE_LAYERS'
,
'PLUGIN_LAYERS'
,
'Scale'
,
'ConvAWS2d'
,
'ConvWS2d'
,
'conv_ws_2d'
,
'DepthwiseSeparableConvModule'
,
'Swish'
,
'Linear'
,
'Conv2dAdaptivePadding'
,
'Conv2d'
,
'ConvTranspose2d'
,
'MaxPool2d'
'Conv2dAdaptivePadding'
,
'Conv2d'
,
'ConvTranspose2d'
,
'MaxPool2d'
,
'ConvTranspose3d'
,
'MaxPool3d'
,
'Conv3d'
]
mmcv/cnn/bricks/wrappers.py
View file @
dfa36dfe
...
...
@@ -8,7 +8,7 @@ import math
import
torch
import
torch.nn
as
nn
from
torch.nn.modules.utils
import
_pair
from
torch.nn.modules.utils
import
_pair
,
_triple
from
.registry
import
CONV_LAYERS
,
UPSAMPLE_LAYERS
...
...
@@ -58,6 +58,27 @@ class Conv2d(nn.Conv2d):
return
super
().
forward
(
x
)
@
CONV_LAYERS
.
register_module
(
'Conv3d'
,
force
=
True
)
class
Conv3d
(
nn
.
Conv3d
):
def
forward
(
self
,
x
):
if
x
.
numel
()
==
0
and
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
4
)):
out_shape
=
[
x
.
shape
[
0
],
self
.
out_channels
]
for
i
,
k
,
p
,
s
,
d
in
zip
(
x
.
shape
[
-
3
:],
self
.
kernel_size
,
self
.
padding
,
self
.
stride
,
self
.
dilation
):
o
=
(
i
+
2
*
p
-
(
d
*
(
k
-
1
)
+
1
))
//
s
+
1
out_shape
.
append
(
o
)
empty
=
NewEmptyTensorOp
.
apply
(
x
,
out_shape
)
if
self
.
training
:
# produce dummy gradient to avoid DDP warning.
dummy
=
sum
(
x
.
view
(
-
1
)[
0
]
for
x
in
self
.
parameters
())
*
0.0
return
empty
+
dummy
else
:
return
empty
return
super
().
forward
(
x
)
@
CONV_LAYERS
.
register_module
()
@
CONV_LAYERS
.
register_module
(
'deconv'
)
@
UPSAMPLE_LAYERS
.
register_module
(
'deconv'
,
force
=
True
)
...
...
@@ -78,7 +99,30 @@ class ConvTranspose2d(nn.ConvTranspose2d):
else
:
return
empty
return
super
(
ConvTranspose2d
,
self
).
forward
(
x
)
return
super
().
forward
(
x
)
@
CONV_LAYERS
.
register_module
()
@
CONV_LAYERS
.
register_module
(
'deconv3d'
)
@
UPSAMPLE_LAYERS
.
register_module
(
'deconv3d'
,
force
=
True
)
class
ConvTranspose3d
(
nn
.
ConvTranspose3d
):
def
forward
(
self
,
x
):
if
x
.
numel
()
==
0
and
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
4
)):
out_shape
=
[
x
.
shape
[
0
],
self
.
out_channels
]
for
i
,
k
,
p
,
s
,
d
,
op
in
zip
(
x
.
shape
[
-
3
:],
self
.
kernel_size
,
self
.
padding
,
self
.
stride
,
self
.
dilation
,
self
.
output_padding
):
out_shape
.
append
((
i
-
1
)
*
s
-
2
*
p
+
(
d
*
(
k
-
1
)
+
1
)
+
op
)
empty
=
NewEmptyTensorOp
.
apply
(
x
,
out_shape
)
if
self
.
training
:
# produce dummy gradient to avoid DDP warning.
dummy
=
sum
(
x
.
view
(
-
1
)[
0
]
for
x
in
self
.
parameters
())
*
0.0
return
empty
+
dummy
else
:
return
empty
return
super
().
forward
(
x
)
class
MaxPool2d
(
nn
.
MaxPool2d
):
...
...
@@ -99,6 +143,25 @@ class MaxPool2d(nn.MaxPool2d):
return
super
().
forward
(
x
)
class
MaxPool3d
(
nn
.
MaxPool3d
):
def
forward
(
self
,
x
):
# PyTorch 1.7 does not support empty tensor inference yet
if
x
.
numel
()
==
0
and
obsolete_torch_version
(
TORCH_VERSION
,
(
1
,
7
)):
out_shape
=
list
(
x
.
shape
[:
2
])
for
i
,
k
,
p
,
s
,
d
in
zip
(
x
.
shape
[
-
3
:],
_triple
(
self
.
kernel_size
),
_triple
(
self
.
padding
),
_triple
(
self
.
stride
),
_triple
(
self
.
dilation
)):
o
=
(
i
+
2
*
p
-
(
d
*
(
k
-
1
)
+
1
))
/
s
+
1
o
=
math
.
ceil
(
o
)
if
self
.
ceil_mode
else
math
.
floor
(
o
)
out_shape
.
append
(
o
)
empty
=
NewEmptyTensorOp
.
apply
(
x
,
out_shape
)
return
empty
return
super
().
forward
(
x
)
class
Linear
(
torch
.
nn
.
Linear
):
def
forward
(
self
,
x
):
...
...
tests/test_cnn/test_wrappers.py
View file @
dfa36dfe
from
collections
import
OrderedDict
from
itertools
import
product
from
unittest.mock
import
patch
import
pytest
import
torch
import
torch.nn
as
nn
from
mmcv.cnn.bricks
import
Conv2d
,
ConvTranspose2d
,
Linear
,
MaxPool2d
from
mmcv.cnn.bricks
import
(
Conv2d
,
Conv3d
,
ConvTranspose2d
,
ConvTranspose3d
,
Linear
,
MaxPool2d
,
MaxPool3d
)
@
patch
(
'torch.__version__'
,
'1.1'
)
def
test_conv2d
():
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
[(
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
def
test_conv2d
(
in_w
,
in_h
,
in_channel
,
out_channel
,
kernel_size
,
stride
,
padding
,
dilation
):
"""
CommandLine:
xdoctest -m tests/test_wrappers.py test_conv2d
"""
# train mode
# wrapper op with 0-dim input
x_empty
=
torch
.
randn
(
0
,
in_channel
,
in_h
,
in_w
)
torch
.
manual_seed
(
0
)
wrapper
=
Conv2d
(
in_channel
,
out_channel
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
wrapper_out
=
wrapper
(
x_empty
)
# torch op with 3-dim input as shape reference
x_normal
=
torch
.
randn
(
3
,
in_channel
,
in_h
,
in_w
).
requires_grad_
(
True
)
torch
.
manual_seed
(
0
)
ref
=
nn
.
Conv2d
(
in_channel
,
out_channel
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
ref_out
=
ref
(
x_normal
)
assert
wrapper_out
.
shape
[
0
]
==
0
assert
wrapper_out
.
shape
[
1
:]
==
ref_out
.
shape
[
1
:]
wrapper_out
.
sum
().
backward
()
assert
wrapper
.
weight
.
grad
is
not
None
assert
wrapper
.
weight
.
grad
.
shape
==
wrapper
.
weight
.
shape
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
test_cases
=
OrderedDict
([(
'in_w'
,
[
10
,
20
]),
(
'in_h'
,
[
10
,
20
]),
(
'in_channel'
,
[
1
,
3
]),
(
'out_channel'
,
[
1
,
3
]),
(
'kernel_size'
,
[
3
,
5
]),
(
'stride'
,
[
1
,
2
]),
(
'padding'
,
[
0
,
1
]),
(
'dilation'
,
[
1
,
2
])])
# eval mode
x_empty
=
torch
.
randn
(
0
,
in_channel
,
in_h
,
in_w
)
wrapper
=
Conv2d
(
in_channel
,
out_channel
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
wrapper
.
eval
()
wrapper
(
x_empty
)
# train mode
for
in_h
,
in_w
,
in_cha
,
out_cha
,
k
,
s
,
p
,
d
in
product
(
*
list
(
test_cases
.
values
())):
# wrapper op with 0-dim input
x_empty
=
torch
.
randn
(
0
,
in_cha
,
in_h
,
in_w
)
torch
.
manual_seed
(
0
)
wrapper
=
Conv2d
(
in_cha
,
out_cha
,
k
,
stride
=
s
,
padding
=
p
,
dilation
=
d
)
wrapper_out
=
wrapper
(
x_empty
)
# torch op with 3-dim input as shape reference
x_normal
=
torch
.
randn
(
3
,
in_cha
,
in_h
,
in_w
).
requires_grad_
(
True
)
torch
.
manual_seed
(
0
)
ref
=
nn
.
Conv2d
(
in_cha
,
out_cha
,
k
,
stride
=
s
,
padding
=
p
,
dilation
=
d
)
ref_out
=
ref
(
x_normal
)
@
patch
(
'torch.__version__'
,
'1.1'
)
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
# noqa: E501
[(
10
,
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
def
test_conv3d
(
in_w
,
in_h
,
in_t
,
in_channel
,
out_channel
,
kernel_size
,
stride
,
padding
,
dilation
):
"""
CommandLine:
xdoctest -m tests/test_wrappers.py test_conv3d
"""
# train mode
# wrapper op with 0-dim input
x_empty
=
torch
.
randn
(
0
,
in_channel
,
in_t
,
in_h
,
in_w
)
torch
.
manual_seed
(
0
)
wrapper
=
Conv3d
(
in_channel
,
out_channel
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
wrapper_out
=
wrapper
(
x_empty
)
# torch op with 3-dim input as shape reference
x_normal
=
torch
.
randn
(
3
,
in_channel
,
in_t
,
in_h
,
in_w
).
requires_grad_
(
True
)
torch
.
manual_seed
(
0
)
ref
=
nn
.
Conv3d
(
in_channel
,
out_channel
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
ref_out
=
ref
(
x_normal
)
assert
wrapper_out
.
shape
[
0
]
==
0
assert
wrapper_out
.
shape
[
1
:]
==
ref_out
.
shape
[
1
:]
wrapper_out
.
sum
().
backward
()
assert
wrapper
.
weight
.
grad
is
not
None
assert
wrapper
.
weight
.
grad
.
shape
==
wrapper
.
weight
.
shape
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
assert
wrapper_out
.
shape
[
0
]
==
0
assert
wrapper_out
.
shape
[
1
:]
==
ref_out
.
shape
[
1
:]
# eval mode
x_empty
=
torch
.
randn
(
0
,
in_channel
,
in_t
,
in_h
,
in_w
)
wrapper
=
Conv3d
(
in_channel
,
out_channel
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
wrapper
.
eval
()
wrapper
(
x_empty
)
wrapper_out
.
sum
().
backward
()
assert
wrapper
.
weight
.
grad
is
not
None
assert
wrapper
.
weight
.
grad
.
shape
==
wrapper
.
weight
.
shape
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
@
patch
(
'torch.__version__'
,
'1.1'
)
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
[(
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
def
test_conv_transposed_2d
(
in_w
,
in_h
,
in_channel
,
out_channel
,
kernel_size
,
stride
,
padding
,
dilation
):
# wrapper op with 0-dim input
x_empty
=
torch
.
randn
(
0
,
in_channel
,
in_h
,
in_w
,
requires_grad
=
True
)
# out padding must be smaller than either stride or dilation
op
=
min
(
stride
,
dilation
)
-
1
torch
.
manual_seed
(
0
)
wrapper
=
ConvTranspose2d
(
in_channel
,
out_channel
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
output_padding
=
op
)
wrapper_out
=
wrapper
(
x_empty
)
# torch op with 3-dim input as shape reference
x_normal
=
torch
.
randn
(
3
,
in_channel
,
in_h
,
in_w
)
torch
.
manual_seed
(
0
)
ref
=
nn
.
ConvTranspose2d
(
in_channel
,
out_channel
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
output_padding
=
op
)
ref_out
=
ref
(
x_normal
)
assert
wrapper_out
.
shape
[
0
]
==
0
assert
wrapper_out
.
shape
[
1
:]
==
ref_out
.
shape
[
1
:]
wrapper_out
.
sum
().
backward
()
assert
wrapper
.
weight
.
grad
is
not
None
assert
wrapper
.
weight
.
grad
.
shape
==
wrapper
.
weight
.
shape
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
# eval mode
x_empty
=
torch
.
randn
(
0
,
in_cha
,
in_h
,
in_w
)
wrapper
=
Conv2d
(
in_cha
,
out_cha
,
k
,
stride
=
s
,
padding
=
p
,
dilation
=
d
)
x_empty
=
torch
.
randn
(
0
,
in_channel
,
in_h
,
in_w
)
wrapper
=
ConvTranspose2d
(
in_channel
,
out_channel
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
output_padding
=
op
)
wrapper
.
eval
()
wrapper
(
x_empty
)
@
patch
(
'torch.__version__'
,
'1.1'
)
def
test_conv_transposed_2d
():
test_cases
=
OrderedDict
([(
'in_w'
,
[
10
,
20
]),
(
'in_h'
,
[
10
,
20
]),
(
'in_channel'
,
[
1
,
3
]),
(
'out_channel'
,
[
1
,
3
]),
(
'kernel_size'
,
[
3
,
5
]),
(
'stride'
,
[
1
,
2
]),
(
'padding'
,
[
0
,
1
]),
(
'dilation'
,
[
1
,
2
])])
for
in_h
,
in_w
,
in_cha
,
out_cha
,
k
,
s
,
p
,
d
in
product
(
*
list
(
test_cases
.
values
())):
# wrapper op with 0-dim input
x_empty
=
torch
.
randn
(
0
,
in_cha
,
in_h
,
in_w
,
requires_grad
=
True
)
# out padding must be smaller than either stride or dilation
op
=
min
(
s
,
d
)
-
1
torch
.
manual_seed
(
0
)
wrapper
=
ConvTranspose2d
(
in_cha
,
out_cha
,
k
,
stride
=
s
,
padding
=
p
,
dilation
=
d
,
output_padding
=
op
)
wrapper_out
=
wrapper
(
x_empty
)
# torch op with 3-dim input as shape reference
x_normal
=
torch
.
randn
(
3
,
in_cha
,
in_h
,
in_w
)
torch
.
manual_seed
(
0
)
ref
=
nn
.
ConvTranspose2d
(
in_cha
,
out_cha
,
k
,
stride
=
s
,
padding
=
p
,
dilation
=
d
,
output_padding
=
op
)
ref_out
=
ref
(
x_normal
)
assert
wrapper_out
.
shape
[
0
]
==
0
assert
wrapper_out
.
shape
[
1
:]
==
ref_out
.
shape
[
1
:]
wrapper_out
.
sum
().
backward
()
assert
wrapper
.
weight
.
grad
is
not
None
assert
wrapper
.
weight
.
grad
.
shape
==
wrapper
.
weight
.
shape
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
# noqa: E501
[(
10
,
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
def
test_conv_transposed_3d
(
in_w
,
in_h
,
in_t
,
in_channel
,
out_channel
,
kernel_size
,
stride
,
padding
,
dilation
):
# wrapper op with 0-dim input
x_empty
=
torch
.
randn
(
0
,
in_channel
,
in_t
,
in_h
,
in_w
,
requires_grad
=
True
)
# out padding must be smaller than either stride or dilation
op
=
min
(
stride
,
dilation
)
-
1
torch
.
manual_seed
(
0
)
wrapper
=
ConvTranspose3d
(
in_channel
,
out_channel
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
output_padding
=
op
)
wrapper_out
=
wrapper
(
x_empty
)
# torch op with 3-dim input as shape reference
x_normal
=
torch
.
randn
(
3
,
in_channel
,
in_t
,
in_h
,
in_w
)
torch
.
manual_seed
(
0
)
ref
=
nn
.
ConvTranspose3d
(
in_channel
,
out_channel
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
output_padding
=
op
)
ref_out
=
ref
(
x_normal
)
assert
wrapper_out
.
shape
[
0
]
==
0
assert
wrapper_out
.
shape
[
1
:]
==
ref_out
.
shape
[
1
:]
wrapper_out
.
sum
().
backward
()
assert
wrapper
.
weight
.
grad
is
not
None
assert
wrapper
.
weight
.
grad
.
shape
==
wrapper
.
weight
.
shape
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
# eval mode
x_empty
=
torch
.
randn
(
0
,
in_cha
,
in_h
,
in_w
)
wrapper
=
ConvTranspose2d
(
in_cha
,
out_cha
,
k
,
stride
=
s
,
padding
=
p
,
dilation
=
d
,
output_padding
=
op
)
x_empty
=
torch
.
randn
(
0
,
in_channel
,
in_t
,
in_h
,
in_w
)
wrapper
=
ConvTranspose3d
(
in_channel
,
out_channel
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
output_padding
=
op
)
wrapper
.
eval
()
wrapper
(
x_empty
)
@
patch
(
'torch.__version__'
,
'1.1'
)
def
test_max_pool_2d
():
test_cases
=
OrderedDict
([(
'in_w'
,
[
10
,
20
]),
(
'in_h'
,
[
10
,
20
]),
(
'in_channel'
,
[
1
,
3
]),
(
'out_channel'
,
[
1
,
3
]),
(
'kernel_size'
,
[
3
,
5
]),
(
'stride'
,
[
1
,
2
]),
(
'padding'
,
[
0
,
1
]),
(
'dilation'
,
[
1
,
2
])])
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
[(
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
def
test_max_pool_2d
(
in_w
,
in_h
,
in_channel
,
out_channel
,
kernel_size
,
stride
,
padding
,
dilation
):
# wrapper op with 0-dim input
x_empty
=
torch
.
randn
(
0
,
in_channel
,
in_h
,
in_w
,
requires_grad
=
True
)
wrapper
=
MaxPool2d
(
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
wrapper_out
=
wrapper
(
x_empty
)
# torch op with 3-dim input as shape reference
x_normal
=
torch
.
randn
(
3
,
in_channel
,
in_h
,
in_w
)
ref
=
nn
.
MaxPool2d
(
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
ref_out
=
ref
(
x_normal
)
for
in_h
,
in_w
,
in_cha
,
out_cha
,
k
,
s
,
p
,
d
in
product
(
*
list
(
test_cases
.
values
())):
# wrapper op with 0-dim input
x_empty
=
torch
.
randn
(
0
,
in_cha
,
in_h
,
in_w
,
requires_grad
=
True
)
wrapper
=
MaxPool2d
(
k
,
stride
=
s
,
padding
=
p
,
dilation
=
d
)
wrapper_out
=
wrapper
(
x_empty
)
assert
wrapper_out
.
shape
[
0
]
==
0
assert
wrapper_out
.
shape
[
1
:]
==
ref_out
.
shape
[
1
:]
# torch op with 3-dim input as shape reference
x_normal
=
torch
.
randn
(
3
,
in_cha
,
in_h
,
in_w
)
ref
=
nn
.
MaxPool2d
(
k
,
stride
=
s
,
padding
=
p
,
dilation
=
d
)
ref_out
=
ref
(
x_normal
)
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
@
patch
(
'torch.__version__'
,
'1.1'
)
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
# noqa: E501
[(
10
,
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
def
test_max_pool_3d
(
in_w
,
in_h
,
in_t
,
in_channel
,
out_channel
,
kernel_size
,
stride
,
padding
,
dilation
):
# wrapper op with 0-dim input
x_empty
=
torch
.
randn
(
0
,
in_channel
,
in_t
,
in_h
,
in_w
,
requires_grad
=
True
)
wrapper
=
MaxPool3d
(
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
wrapper_out
=
wrapper
(
x_empty
)
assert
wrapper_out
.
shape
[
0
]
==
0
assert
wrapper_out
.
shape
[
1
:]
==
ref_out
.
shape
[
1
:]
# torch op with 3-dim input as shape reference
x_normal
=
torch
.
randn
(
3
,
in_channel
,
in_t
,
in_h
,
in_w
)
ref
=
nn
.
MaxPool3d
(
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
ref_out
=
ref
(
x_normal
)
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
assert
wrapper_out
.
shape
[
0
]
==
0
assert
wrapper_out
.
shape
[
1
:]
==
ref_out
.
shape
[
1
:]
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
@
patch
(
'torch.__version__'
,
'1.1'
)
def
test_linear
():
test_cases
=
OrderedDict
([
(
'in_w'
,
[
10
,
20
]),
(
'in_h'
,
[
10
,
20
]),
(
'in_feature'
,
[
1
,
3
]),
(
'out_feature'
,
[
1
,
3
]),
])
for
in_h
,
in_w
,
in_feature
,
out_feature
in
product
(
*
list
(
test_cases
.
values
())):
# wrapper op with 0-dim input
x_empty
=
torch
.
randn
(
0
,
in_feature
,
requires_grad
=
True
)
torch
.
manual_seed
(
0
)
wrapper
=
Linear
(
in_feature
,
out_feature
)
wrapper_out
=
wrapper
(
x_empty
)
# torch op with 3-dim input as shape reference
x_normal
=
torch
.
randn
(
3
,
in_feature
)
torch
.
manual_seed
(
0
)
ref
=
nn
.
Linear
(
in_feature
,
out_feature
)
ref_out
=
ref
(
x_normal
)
assert
wrapper_out
.
shape
[
0
]
==
0
assert
wrapper_out
.
shape
[
1
:]
==
ref_out
.
shape
[
1
:]
wrapper_out
.
sum
().
backward
()
assert
wrapper
.
weight
.
grad
is
not
None
assert
wrapper
.
weight
.
grad
.
shape
==
wrapper
.
weight
.
shape
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_feature,out_feature'
,
[(
10
,
10
,
1
,
1
),
(
20
,
20
,
3
,
3
)])
def
test_linear
(
in_w
,
in_h
,
in_feature
,
out_feature
):
# wrapper op with 0-dim input
x_empty
=
torch
.
randn
(
0
,
in_feature
,
requires_grad
=
True
)
torch
.
manual_seed
(
0
)
wrapper
=
Linear
(
in_feature
,
out_feature
)
wrapper_out
=
wrapper
(
x_empty
)
# torch op with 3-dim input as shape reference
x_normal
=
torch
.
randn
(
3
,
in_feature
)
torch
.
manual_seed
(
0
)
ref
=
nn
.
Linear
(
in_feature
,
out_feature
)
ref_out
=
ref
(
x_normal
)
assert
wrapper_out
.
shape
[
0
]
==
0
assert
wrapper_out
.
shape
[
1
:]
==
ref_out
.
shape
[
1
:]
wrapper_out
.
sum
().
backward
()
assert
wrapper
.
weight
.
grad
is
not
None
assert
wrapper
.
weight
.
grad
.
shape
==
wrapper
.
weight
.
shape
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
# eval mode
x_empty
=
torch
.
randn
(
0
,
in_feature
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment