Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
7b18b977
Unverified
Commit
7b18b977
authored
Aug 16, 2020
by
Cao Yuhang
Committed by
GitHub
Aug 16, 2020
Browse files
fix saconv (#489)
* fix saconv * add parrots condition * add unittest * fix torch version
parent
eacaf475
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
2 deletions
+55
-2
mmcv/ops/saconv.py
mmcv/ops/saconv.py
+9
-2
tests/test_ops/test_saconv.py
tests/test_ops/test_saconv.py
+46
-0
No files found.
mmcv/ops/saconv.py
View file @
7b18b977
...
@@ -4,6 +4,7 @@ import torch.nn.functional as F
...
@@ -4,6 +4,7 @@ import torch.nn.functional as F
from
mmcv.cnn
import
CONV_LAYERS
,
ConvAWS2d
,
constant_init
from
mmcv.cnn
import
CONV_LAYERS
,
ConvAWS2d
,
constant_init
from
mmcv.ops.deform_conv
import
deform_conv2d
from
mmcv.ops.deform_conv
import
deform_conv2d
from
mmcv.utils
import
TORCH_VERSION
@
CONV_LAYERS
.
register_module
(
name
=
'SAC'
)
@
CONV_LAYERS
.
register_module
(
name
=
'SAC'
)
...
@@ -102,7 +103,10 @@ class SAConv2d(ConvAWS2d):
...
@@ -102,7 +103,10 @@ class SAConv2d(ConvAWS2d):
out_s
=
deform_conv2d
(
x
,
offset
,
weight
,
self
.
stride
,
self
.
padding
,
out_s
=
deform_conv2d
(
x
,
offset
,
weight
,
self
.
stride
,
self
.
padding
,
self
.
dilation
,
self
.
groups
,
1
)
self
.
dilation
,
self
.
groups
,
1
)
else
:
else
:
out_s
=
super
().
conv2d_forward
(
x
,
weight
)
if
TORCH_VERSION
<
'1.5.0'
or
TORCH_VERSION
==
'parrots'
:
out_s
=
super
().
conv2d_forward
(
x
,
weight
)
else
:
out_s
=
super
().
_conv_forward
(
x
,
weight
)
ori_p
=
self
.
padding
ori_p
=
self
.
padding
ori_d
=
self
.
dilation
ori_d
=
self
.
dilation
self
.
padding
=
tuple
(
3
*
p
for
p
in
self
.
padding
)
self
.
padding
=
tuple
(
3
*
p
for
p
in
self
.
padding
)
...
@@ -113,7 +117,10 @@ class SAConv2d(ConvAWS2d):
...
@@ -113,7 +117,10 @@ class SAConv2d(ConvAWS2d):
out_l
=
deform_conv2d
(
x
,
offset
,
weight
,
self
.
stride
,
self
.
padding
,
out_l
=
deform_conv2d
(
x
,
offset
,
weight
,
self
.
stride
,
self
.
padding
,
self
.
dilation
,
self
.
groups
,
1
)
self
.
dilation
,
self
.
groups
,
1
)
else
:
else
:
out_l
=
super
().
conv2d_forward
(
x
,
weight
)
if
TORCH_VERSION
<
'1.5.0'
or
TORCH_VERSION
==
'parrots'
:
out_l
=
super
().
conv2d_forward
(
x
,
weight
)
else
:
out_l
=
super
().
_conv_forward
(
x
,
weight
)
out
=
switch
*
out_s
+
(
1
-
switch
)
*
out_l
out
=
switch
*
out_s
+
(
1
-
switch
)
*
out_l
self
.
padding
=
ori_p
self
.
padding
=
ori_p
self
.
dilation
=
ori_d
self
.
dilation
=
ori_d
...
...
tests/test_ops/test_saconv.py
0 → 100644
View file @
7b18b977
import
pytest
import
torch
import
torch.nn
as
nn
from
mmcv.ops
import
SAConv2d
def
test_sacconv
():
# test with normal cast
x
=
torch
.
rand
(
1
,
3
,
256
,
256
)
saconv
=
SAConv2d
(
3
,
5
,
kernel_size
=
3
,
padding
=
1
)
sac_out
=
saconv
(
x
)
refer_conv
=
nn
.
Conv2d
(
3
,
5
,
kernel_size
=
3
,
padding
=
1
)
refer_out
=
refer_conv
(
x
)
assert
sac_out
.
shape
==
refer_out
.
shape
# test with dilation >= 2
dalited_saconv
=
SAConv2d
(
3
,
5
,
kernel_size
=
3
,
padding
=
2
,
dilation
=
2
)
dalited_sac_out
=
dalited_saconv
(
x
)
refer_conv
=
nn
.
Conv2d
(
3
,
5
,
kernel_size
=
3
,
padding
=
2
,
dilation
=
2
)
refer_out
=
refer_conv
(
x
)
assert
dalited_sac_out
.
shape
==
refer_out
.
shape
# test with deform
deform_saconv
=
SAConv2d
(
3
,
5
,
kernel_size
=
3
,
padding
=
1
,
use_deform
=
True
)
if
torch
.
cuda
.
is_available
():
x
=
torch
.
rand
(
1
,
3
,
256
,
256
).
cuda
()
deform_saconv
=
SAConv2d
(
3
,
5
,
kernel_size
=
3
,
padding
=
1
,
use_deform
=
True
).
cuda
()
deform_sac_out
=
deform_saconv
(
x
).
cuda
()
refer_conv
=
nn
.
Conv2d
(
3
,
5
,
kernel_size
=
3
,
padding
=
1
).
cuda
()
refer_out
=
refer_conv
(
x
)
assert
deform_sac_out
.
shape
==
refer_out
.
shape
else
:
with
pytest
.
raises
(
RuntimeError
):
# deform conv is not implemented on cpu
deform_saconv
(
x
)
# test with groups >= 2
x
=
torch
.
rand
(
1
,
4
,
256
,
256
)
group_saconv
=
SAConv2d
(
4
,
4
,
kernel_size
=
3
,
padding
=
1
,
groups
=
2
)
group_sac_out
=
group_saconv
(
x
)
refer_conv
=
nn
.
Conv2d
(
4
,
4
,
kernel_size
=
3
,
padding
=
1
,
groups
=
2
)
refer_out
=
refer_conv
(
x
)
assert
group_sac_out
.
shape
==
refer_out
.
shape
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment