Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
655f3c3f
Unverified
Commit
655f3c3f
authored
Dec 24, 2020
by
BigBigDream
Committed by
GitHub
Dec 24, 2020
Browse files
fix mmcv_ci test_wrappers.py for parrots (#758)
parent
86e0d62a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
8 deletions
+21
-8
tests/test_cnn/test_wrappers.py
tests/test_cnn/test_wrappers.py
+21
-8
No files found.
tests/test_cnn/test_wrappers.py
View file @
655f3c3f
...
@@ -7,8 +7,13 @@ import torch.nn as nn
...
@@ -7,8 +7,13 @@ import torch.nn as nn
from
mmcv.cnn.bricks
import
(
Conv2d
,
Conv3d
,
ConvTranspose2d
,
ConvTranspose3d
,
from
mmcv.cnn.bricks
import
(
Conv2d
,
Conv3d
,
ConvTranspose2d
,
ConvTranspose3d
,
Linear
,
MaxPool2d
,
MaxPool3d
)
Linear
,
MaxPool2d
,
MaxPool3d
)
if
torch
.
__version__
!=
'parrots'
:
torch_version
=
'1.1'
else
:
torch_version
=
'parrots'
@
patch
(
'torch.__version__'
,
'1.1'
)
@
patch
(
'torch.__version__'
,
torch_version
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
[(
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
[(
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
...
@@ -65,7 +70,7 @@ def test_conv2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
...
@@ -65,7 +70,7 @@ def test_conv2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
wrapper
(
x_empty
)
wrapper
(
x_empty
)
@
patch
(
'torch.__version__'
,
'1.1'
)
@
patch
(
'torch.__version__'
,
torch_version
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
# noqa: E501
'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
# noqa: E501
[(
10
,
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
[(
10
,
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
...
@@ -123,7 +128,7 @@ def test_conv3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size, stride,
...
@@ -123,7 +128,7 @@ def test_conv3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size, stride,
wrapper
(
x_empty
)
wrapper
(
x_empty
)
@
patch
(
'torch.__version__'
,
'1.1'
)
@
patch
(
'torch.__version__'
,
torch_version
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
[(
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
[(
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
...
@@ -133,6 +138,8 @@ def test_conv_transposed_2d(in_w, in_h, in_channel, out_channel, kernel_size,
...
@@ -133,6 +138,8 @@ def test_conv_transposed_2d(in_w, in_h, in_channel, out_channel, kernel_size,
x_empty
=
torch
.
randn
(
0
,
in_channel
,
in_h
,
in_w
,
requires_grad
=
True
)
x_empty
=
torch
.
randn
(
0
,
in_channel
,
in_h
,
in_w
,
requires_grad
=
True
)
# out padding must be smaller than either stride or dilation
# out padding must be smaller than either stride or dilation
op
=
min
(
stride
,
dilation
)
-
1
op
=
min
(
stride
,
dilation
)
-
1
if
torch
.
__version__
==
'parrots'
:
op
=
0
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
wrapper
=
ConvTranspose2d
(
wrapper
=
ConvTranspose2d
(
in_channel
,
in_channel
,
...
@@ -180,7 +187,7 @@ def test_conv_transposed_2d(in_w, in_h, in_channel, out_channel, kernel_size,
...
@@ -180,7 +187,7 @@ def test_conv_transposed_2d(in_w, in_h, in_channel, out_channel, kernel_size,
wrapper
(
x_empty
)
wrapper
(
x_empty
)
@
patch
(
'torch.__version__'
,
'1.1'
)
@
patch
(
'torch.__version__'
,
torch_version
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
# noqa: E501
'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
# noqa: E501
[(
10
,
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
[(
10
,
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
...
@@ -237,7 +244,7 @@ def test_conv_transposed_3d(in_w, in_h, in_t, in_channel, out_channel,
...
@@ -237,7 +244,7 @@ def test_conv_transposed_3d(in_w, in_h, in_t, in_channel, out_channel,
wrapper
(
x_empty
)
wrapper
(
x_empty
)
@
patch
(
'torch.__version__'
,
'1.1'
)
@
patch
(
'torch.__version__'
,
torch_version
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
[(
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
[(
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
...
@@ -261,22 +268,28 @@ def test_max_pool_2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
...
@@ -261,22 +268,28 @@ def test_max_pool_2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
@
patch
(
'torch.__version__'
,
'1.1'
)
@
patch
(
'torch.__version__'
,
torch_version
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
# noqa: E501
'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation'
,
# noqa: E501
[(
10
,
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
[(
10
,
10
,
10
,
1
,
1
,
3
,
1
,
0
,
1
),
(
20
,
20
,
20
,
3
,
3
,
5
,
2
,
1
,
2
)])
@
pytest
.
mark
.
skipif
(
torch
.
__version__
==
'parrots'
and
not
torch
.
cuda
.
is_available
(),
reason
=
'parrots requires CUDA support'
)
def
test_max_pool_3d
(
in_w
,
in_h
,
in_t
,
in_channel
,
out_channel
,
kernel_size
,
def
test_max_pool_3d
(
in_w
,
in_h
,
in_t
,
in_channel
,
out_channel
,
kernel_size
,
stride
,
padding
,
dilation
):
stride
,
padding
,
dilation
):
# wrapper op with 0-dim input
# wrapper op with 0-dim input
x_empty
=
torch
.
randn
(
0
,
in_channel
,
in_t
,
in_h
,
in_w
,
requires_grad
=
True
)
x_empty
=
torch
.
randn
(
0
,
in_channel
,
in_t
,
in_h
,
in_w
,
requires_grad
=
True
)
wrapper
=
MaxPool3d
(
wrapper
=
MaxPool3d
(
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
if
torch
.
__version__
==
'parrots'
:
x_empty
=
x_empty
.
cuda
()
wrapper_out
=
wrapper
(
x_empty
)
wrapper_out
=
wrapper
(
x_empty
)
# torch op with 3-dim input as shape reference
# torch op with 3-dim input as shape reference
x_normal
=
torch
.
randn
(
3
,
in_channel
,
in_t
,
in_h
,
in_w
)
x_normal
=
torch
.
randn
(
3
,
in_channel
,
in_t
,
in_h
,
in_w
)
ref
=
nn
.
MaxPool3d
(
ref
=
nn
.
MaxPool3d
(
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
)
if
torch
.
__version__
==
'parrots'
:
x_normal
=
x_normal
.
cuda
()
ref_out
=
ref
(
x_normal
)
ref_out
=
ref
(
x_normal
)
assert
wrapper_out
.
shape
[
0
]
==
0
assert
wrapper_out
.
shape
[
0
]
==
0
...
@@ -285,7 +298,7 @@ def test_max_pool_3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size,
...
@@ -285,7 +298,7 @@ def test_max_pool_3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size,
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
assert
torch
.
equal
(
wrapper
(
x_normal
),
ref_out
)
@
patch
(
'torch.__version__'
,
'1.1'
)
@
patch
(
'torch.__version__'
,
torch_version
)
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_feature,out_feature'
,
[(
10
,
10
,
1
,
1
),
@
pytest
.
mark
.
parametrize
(
'in_w,in_h,in_feature,out_feature'
,
[(
10
,
10
,
1
,
1
),
(
20
,
20
,
3
,
3
)])
(
20
,
20
,
3
,
3
)])
def
test_linear
(
in_w
,
in_h
,
in_feature
,
out_feature
):
def
test_linear
(
in_w
,
in_h
,
in_feature
,
out_feature
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment