Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
86cc430a
Unverified
Commit
86cc430a
authored
Jul 29, 2019
by
Kai Chen
Committed by
GitHub
Jul 29, 2019
Browse files
Restructure the ops directory (#1073)
* restructure the ops directory * add some repr strings
parent
8387aba8
Changes
30
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
329 additions
and
325 deletions
+329
-325
mmdet/ops/__init__.py
mmdet/ops/__init__.py
+1
-1
mmdet/ops/context_block.py
mmdet/ops/context_block.py
+0
-0
mmdet/ops/dcn/__init__.py
mmdet/ops/dcn/__init__.py
+5
-6
mmdet/ops/dcn/deform_conv.py
mmdet/ops/dcn/deform_conv.py
+157
-1
mmdet/ops/dcn/deform_pool.py
mmdet/ops/dcn/deform_pool.py
+95
-22
mmdet/ops/dcn/functions/__init__.py
mmdet/ops/dcn/functions/__init__.py
+0
-0
mmdet/ops/dcn/functions/deform_pool.py
mmdet/ops/dcn/functions/deform_pool.py
+0
-69
mmdet/ops/dcn/modules/__init__.py
mmdet/ops/dcn/modules/__init__.py
+0
-0
mmdet/ops/dcn/modules/deform_conv.py
mmdet/ops/dcn/modules/deform_conv.py
+0
-157
mmdet/ops/gcb/__init__.py
mmdet/ops/gcb/__init__.py
+0
-5
mmdet/ops/masked_conv/__init__.py
mmdet/ops/masked_conv/__init__.py
+1
-2
mmdet/ops/masked_conv/functions/__init__.py
mmdet/ops/masked_conv/functions/__init__.py
+0
-0
mmdet/ops/masked_conv/masked_conv.py
mmdet/ops/masked_conv/masked_conv.py
+34
-1
mmdet/ops/masked_conv/modules/__init__.py
mmdet/ops/masked_conv/modules/__init__.py
+0
-0
mmdet/ops/masked_conv/modules/masked_conv.py
mmdet/ops/masked_conv/modules/masked_conv.py
+0
-30
mmdet/ops/roi_align/__init__.py
mmdet/ops/roi_align/__init__.py
+1
-2
mmdet/ops/roi_align/functions/__init__.py
mmdet/ops/roi_align/functions/__init__.py
+0
-0
mmdet/ops/roi_align/modules/__init__.py
mmdet/ops/roi_align/modules/__init__.py
+0
-0
mmdet/ops/roi_align/modules/roi_align.py
mmdet/ops/roi_align/modules/roi_align.py
+0
-28
mmdet/ops/roi_align/roi_align.py
mmdet/ops/roi_align/roi_align.py
+35
-1
No files found.
mmdet/ops/__init__.py
View file @
86cc430a
...
...
@@ -2,7 +2,7 @@ from .dcn import (DeformConv, DeformConvPack, ModulatedDeformConv,
ModulatedDeformConvPack
,
DeformRoIPooling
,
DeformRoIPoolingPack
,
ModulatedDeformRoIPoolingPack
,
deform_conv
,
modulated_deform_conv
,
deform_roi_pooling
)
from
.
gcb
import
ContextBlock
from
.
context_block
import
ContextBlock
from
.nms
import
nms
,
soft_nms
from
.roi_align
import
RoIAlign
,
roi_align
from
.roi_pool
import
RoIPool
,
roi_pool
...
...
mmdet/ops/
gcb/
context_block.py
→
mmdet/ops/context_block.py
View file @
86cc430a
File moved
mmdet/ops/dcn/__init__.py
View file @
86cc430a
from
.functions.deform_conv
import
deform_conv
,
modulated_deform_conv
from
.functions.deform_pool
import
deform_roi_pooling
from
.modules.deform_conv
import
(
DeformConv
,
ModulatedDeformConv
,
DeformConvPack
,
ModulatedDeformConvPack
)
from
.modules.deform_pool
import
(
DeformRoIPooling
,
DeformRoIPoolingPack
,
ModulatedDeformRoIPoolingPack
)
from
.deform_conv
import
(
deform_conv
,
modulated_deform_conv
,
DeformConv
,
DeformConvPack
,
ModulatedDeformConv
,
ModulatedDeformConvPack
)
from
.deform_pool
import
(
deform_roi_pooling
,
DeformRoIPooling
,
DeformRoIPoolingPack
,
ModulatedDeformRoIPoolingPack
)
__all__
=
[
'DeformConv'
,
'DeformConvPack'
,
'ModulatedDeformConv'
,
...
...
mmdet/ops/dcn/
functions/
deform_conv.py
→
mmdet/ops/dcn/deform_conv.py
View file @
86cc430a
import
math
import
torch
import
torch.nn
as
nn
from
torch.autograd
import
Function
from
torch.autograd.function
import
once_differentiable
from
torch.nn.modules.utils
import
_pair
from
.
.
import
deform_conv_cuda
from
.
import
deform_conv_cuda
class
DeformConvFunction
(
Function
):
...
...
@@ -52,6 +56,7 @@ class DeformConvFunction(Function):
return
output
@
staticmethod
@
once_differentiable
def
backward
(
ctx
,
grad_output
):
input
,
offset
,
weight
=
ctx
.
saved_tensors
...
...
@@ -143,6 +148,7 @@ class ModulatedDeformConvFunction(Function):
return
output
@
staticmethod
@
once_differentiable
def
backward
(
ctx
,
grad_output
):
if
not
grad_output
.
is_cuda
:
raise
NotImplementedError
...
...
@@ -179,3 +185,153 @@ class ModulatedDeformConvFunction(Function):
deform_conv
=
DeformConvFunction
.
apply
modulated_deform_conv
=
ModulatedDeformConvFunction
.
apply
class DeformConv(nn.Module):
    """Deformable convolution layer (DCNv1).

    Holds only the convolution weight; the sampling offsets are supplied
    by the caller at forward time. ``bias`` is not supported and must
    stay False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=False):
        super(DeformConv, self).__init__()
        assert not bias
        assert in_channels % groups == 0, \
            'in_channels {} cannot be divisible by groups {}'.format(
                in_channels, groups)
        assert out_channels % groups == 0, \
            'out_channels {} cannot be divisible by groups {}'.format(
                out_channels, groups)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups
        # Weight layout: (out_channels, in_channels // groups, kh, kw).
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // self.groups,
                         *self.kernel_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)]."""
        fan_in = self.in_channels
        for dim in self.kernel_size:
            fan_in *= dim
        bound = 1. / math.sqrt(fan_in)
        self.weight.data.uniform_(-bound, bound)

    def forward(self, x, offset):
        # Offsets come from the caller; see DeformConvPack for a
        # self-contained variant that predicts them internally.
        return deform_conv(x, offset, self.weight, self.stride, self.padding,
                           self.dilation, self.groups, self.deformable_groups)
class DeformConvPack(DeformConv):
    """DeformConv that predicts its own offsets with an internal conv.

    The offset branch emits ``deformable_groups * 2 * kh * kw`` channels
    (one (x, y) pair per kernel position) and is zero-initialized, so the
    layer initially behaves like a plain convolution.
    """

    def __init__(self, *args, **kwargs):
        super(DeformConvPack, self).__init__(*args, **kwargs)
        offset_channels = (self.deformable_groups * 2 *
                           self.kernel_size[0] * self.kernel_size[1])
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            offset_channels,
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            bias=True)
        self.init_offset()

    def init_offset(self):
        # Zero offsets at start-up -> regular convolution behavior.
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        offset = self.conv_offset(x)
        return deform_conv(x, offset, self.weight, self.stride, self.padding,
                           self.dilation, self.groups, self.deformable_groups)
class ModulatedDeformConv(nn.Module):
    """Modulated deformable convolution (DCNv2) module.

    Unlike ``DeformConv``, stride/padding/dilation are stored exactly as
    given (not normalized to pairs) and an optional bias is supported.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=True):
        super(ModulatedDeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            # Keep the attribute visible with value None for callers.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform weight init scaled by fan-in; bias starts at zero."""
        fan_in = self.in_channels
        for dim in self.kernel_size:
            fan_in *= dim
        bound = 1. / math.sqrt(fan_in)
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x, offset, mask):
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
    """ModulatedDeformConv that predicts offsets and mask internally.

    A single conv emits ``deformable_groups * 3 * kh * kw`` channels: two
    thirds form the (x, y) offsets, the last third the modulation mask,
    which is squashed with a sigmoid. The branch is zero-initialized.
    """

    def __init__(self, *args, **kwargs):
        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
        branch_channels = (self.deformable_groups * 3 *
                           self.kernel_size[0] * self.kernel_size[1])
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            branch_channels,
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            bias=True)
        self.init_offset()

    def init_offset(self):
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, x):
        out = self.conv_offset_mask(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)
mmdet/ops/dcn/
modules/
deform_pool.py
→
mmdet/ops/dcn/deform_pool.py
View file @
86cc430a
from
torch
import
nn
import
torch
import
torch.nn
as
nn
from
torch.autograd
import
Function
from
torch.autograd.function
import
once_differentiable
from
..functions.deform_pool
import
deform_roi_pooling
from
.
import
deform_pool_cuda
class DeformRoIPoolingFunction(Function):
    # autograd Function binding the compiled deform_pool_cuda extension
    # (deformable position-sensitive RoI pooling). CUDA-only: both passes
    # raise NotImplementedError for CPU tensors.

    @staticmethod
    def forward(ctx,
                data,
                rois,
                offset,
                spatial_scale,
                out_size,
                out_channels,
                no_trans,
                group_size=1,
                part_size=None,
                sample_per_part=4,
                trans_std=.0):
        """Pool ``rois`` from ``data``, optionally deformed by ``offset``.

        Hyper-parameters are cached on ``ctx`` so backward replays the
        same configuration. Returns a tensor of shape
        (n_rois, out_channels, out_size, out_size).
        """
        ctx.spatial_scale = spatial_scale
        ctx.out_size = out_size
        ctx.out_channels = out_channels
        ctx.no_trans = no_trans
        ctx.group_size = group_size
        # part_size defaults to the output resolution when not given.
        ctx.part_size = out_size if part_size is None else part_size
        ctx.sample_per_part = sample_per_part
        ctx.trans_std = trans_std
        assert 0.0 <= ctx.trans_std <= 1.0
        if not data.is_cuda:
            raise NotImplementedError

        n = rois.shape[0]
        output = data.new_empty(n, out_channels, out_size, out_size)
        # output_count is filled by the kernel and reused in backward.
        output_count = data.new_empty(n, out_channels, out_size, out_size)
        deform_pool_cuda.deform_psroi_pooling_cuda_forward(
            data, rois, offset, output, output_count, ctx.no_trans,
            ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
            ctx.part_size, ctx.sample_per_part, ctx.trans_std)

        # Save inputs only when some gradient can actually flow back.
        if data.requires_grad or rois.requires_grad or offset.requires_grad:
            ctx.save_for_backward(data, rois, offset)
        ctx.output_count = output_count

        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Gradients w.r.t. data and offset; rois receive no gradient."""
        if not grad_output.is_cuda:
            raise NotImplementedError

        data, rois, offset = ctx.saved_tensors
        output_count = ctx.output_count
        grad_input = torch.zeros_like(data)
        grad_rois = None
        grad_offset = torch.zeros_like(offset)

        deform_pool_cuda.deform_psroi_pooling_cuda_backward(
            grad_output, data, rois, offset, output_count, grad_input,
            grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
            ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
            ctx.trans_std)
        # One gradient slot per forward argument after ctx (11 total; the
        # eight non-tensor hyper-parameters get None).
        return (grad_input, grad_rois, grad_offset, None, None, None, None,
                None, None, None, None)


# Functional alias used by the module wrappers below.
deform_roi_pooling = DeformRoIPoolingFunction.apply
class
DeformRoIPooling
(
nn
.
Module
):
...
...
@@ -27,10 +96,11 @@ class DeformRoIPooling(nn.Module):
def
forward
(
self
,
data
,
rois
,
offset
):
if
self
.
no_trans
:
offset
=
data
.
new_empty
(
0
)
return
deform_roi_pooling
(
data
,
rois
,
offset
,
self
.
spatial_scale
,
self
.
out_size
,
self
.
out_channels
,
self
.
no_trans
,
self
.
group_size
,
self
.
part_size
,
self
.
sample_per_part
,
self
.
trans_std
)
return
deform_roi_pooling
(
data
,
rois
,
offset
,
self
.
spatial_scale
,
self
.
out_size
,
self
.
out_channels
,
self
.
no_trans
,
self
.
group_size
,
self
.
part_size
,
self
.
sample_per_part
,
self
.
trans_std
)
class
DeformRoIPoolingPack
(
DeformRoIPooling
):
...
...
@@ -73,10 +143,11 @@ class DeformRoIPoolingPack(DeformRoIPooling):
assert
data
.
size
(
1
)
==
self
.
out_channels
if
self
.
no_trans
:
offset
=
data
.
new_empty
(
0
)
return
deform_roi_pooling
(
data
,
rois
,
offset
,
self
.
spatial_scale
,
self
.
out_size
,
self
.
out_channels
,
self
.
no_trans
,
self
.
group_size
,
self
.
part_size
,
self
.
sample_per_part
,
self
.
trans_std
)
return
deform_roi_pooling
(
data
,
rois
,
offset
,
self
.
spatial_scale
,
self
.
out_size
,
self
.
out_channels
,
self
.
no_trans
,
self
.
group_size
,
self
.
part_size
,
self
.
sample_per_part
,
self
.
trans_std
)
else
:
n
=
rois
.
shape
[
0
]
offset
=
data
.
new_empty
(
0
)
...
...
@@ -86,10 +157,11 @@ class DeformRoIPoolingPack(DeformRoIPooling):
self
.
sample_per_part
,
self
.
trans_std
)
offset
=
self
.
offset_fc
(
x
.
view
(
n
,
-
1
))
offset
=
offset
.
view
(
n
,
2
,
self
.
out_size
,
self
.
out_size
)
return
deform_roi_pooling
(
data
,
rois
,
offset
,
self
.
spatial_scale
,
self
.
out_size
,
self
.
out_channels
,
self
.
no_trans
,
self
.
group_size
,
self
.
part_size
,
self
.
sample_per_part
,
self
.
trans_std
)
return
deform_roi_pooling
(
data
,
rois
,
offset
,
self
.
spatial_scale
,
self
.
out_size
,
self
.
out_channels
,
self
.
no_trans
,
self
.
group_size
,
self
.
part_size
,
self
.
sample_per_part
,
self
.
trans_std
)
class
ModulatedDeformRoIPoolingPack
(
DeformRoIPooling
):
...
...
@@ -106,9 +178,9 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
num_offset_fcs
=
3
,
num_mask_fcs
=
2
,
deform_fc_channels
=
1024
):
super
(
ModulatedDeformRoIPoolingPack
,
self
).
__init__
(
spatial_scale
,
out_size
,
out_channels
,
no_trans
,
group_size
,
part_size
,
sample_per_part
,
trans_std
)
super
(
ModulatedDeformRoIPoolingPack
,
self
).
__init__
(
spatial_scale
,
out_size
,
out_channels
,
no_trans
,
group_size
,
part_size
,
sample_per_part
,
trans_std
)
self
.
num_offset_fcs
=
num_offset_fcs
self
.
num_mask_fcs
=
num_mask_fcs
...
...
@@ -151,10 +223,11 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
assert
data
.
size
(
1
)
==
self
.
out_channels
if
self
.
no_trans
:
offset
=
data
.
new_empty
(
0
)
return
deform_roi_pooling
(
data
,
rois
,
offset
,
self
.
spatial_scale
,
self
.
out_size
,
self
.
out_channels
,
self
.
no_trans
,
self
.
group_size
,
self
.
part_size
,
self
.
sample_per_part
,
self
.
trans_std
)
return
deform_roi_pooling
(
data
,
rois
,
offset
,
self
.
spatial_scale
,
self
.
out_size
,
self
.
out_channels
,
self
.
no_trans
,
self
.
group_size
,
self
.
part_size
,
self
.
sample_per_part
,
self
.
trans_std
)
else
:
n
=
rois
.
shape
[
0
]
offset
=
data
.
new_empty
(
0
)
...
...
mmdet/ops/dcn/functions/__init__.py
deleted
100644 → 0
View file @
8387aba8
mmdet/ops/dcn/functions/deform_pool.py
deleted
100644 → 0
View file @
8387aba8
import
torch
from
torch.autograd
import
Function
from
..
import
deform_pool_cuda
class DeformRoIPoolingFunction(Function):
    # Pre-restructure copy of the deformable PS-RoI pooling Function
    # (this file was deleted by the commit). CUDA-only.

    @staticmethod
    def forward(ctx,
                data,
                rois,
                offset,
                spatial_scale,
                out_size,
                out_channels,
                no_trans,
                group_size=1,
                part_size=None,
                sample_per_part=4,
                trans_std=.0):
        """Pool ``rois`` from ``data``; configuration is cached on ctx."""
        ctx.spatial_scale = spatial_scale
        ctx.out_size = out_size
        ctx.out_channels = out_channels
        ctx.no_trans = no_trans
        ctx.group_size = group_size
        # part_size defaults to the output resolution when not given.
        ctx.part_size = out_size if part_size is None else part_size
        ctx.sample_per_part = sample_per_part
        ctx.trans_std = trans_std
        assert 0.0 <= ctx.trans_std <= 1.0
        if not data.is_cuda:
            raise NotImplementedError

        n = rois.shape[0]
        output = data.new_empty(n, out_channels, out_size, out_size)
        # output_count is filled by the kernel and reused in backward.
        output_count = data.new_empty(n, out_channels, out_size, out_size)
        deform_pool_cuda.deform_psroi_pooling_cuda_forward(
            data, rois, offset, output, output_count, ctx.no_trans,
            ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
            ctx.part_size, ctx.sample_per_part, ctx.trans_std)

        # Save inputs only when some gradient can actually flow back.
        if data.requires_grad or rois.requires_grad or offset.requires_grad:
            ctx.save_for_backward(data, rois, offset)
        ctx.output_count = output_count

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Gradients w.r.t. data and offset; rois receive no gradient."""
        if not grad_output.is_cuda:
            raise NotImplementedError

        data, rois, offset = ctx.saved_tensors
        output_count = ctx.output_count
        grad_input = torch.zeros_like(data)
        grad_rois = None
        grad_offset = torch.zeros_like(offset)

        deform_pool_cuda.deform_psroi_pooling_cuda_backward(
            grad_output, data, rois, offset, output_count, grad_input,
            grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
            ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
            ctx.trans_std)
        # One gradient slot per forward argument after ctx (11 total).
        return (grad_input, grad_rois, grad_offset, None, None, None, None,
                None, None, None, None)


deform_roi_pooling = DeformRoIPoolingFunction.apply
mmdet/ops/dcn/modules/__init__.py
deleted
100644 → 0
View file @
8387aba8
mmdet/ops/dcn/modules/deform_conv.py
deleted
100644 → 0
View file @
8387aba8
import
math
import
torch
import
torch.nn
as
nn
from
torch.nn.modules.utils
import
_pair
from
..functions.deform_conv
import
deform_conv
,
modulated_deform_conv
class DeformConv(nn.Module):
    """Deformable convolution layer (DCNv1); pre-restructure copy.

    Holds only the convolution weight; sampling offsets are supplied by
    the caller at forward time. ``bias`` must stay False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=False):
        super(DeformConv, self).__init__()
        assert not bias
        assert in_channels % groups == 0, \
            'in_channels {} cannot be divisible by groups {}'.format(
                in_channels, groups)
        assert out_channels % groups == 0, \
            'out_channels {} cannot be divisible by groups {}'.format(
                out_channels, groups)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups
        # Weight layout: (out_channels, in_channels // groups, kh, kw).
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // self.groups,
                         *self.kernel_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)]."""
        fan_in = self.in_channels
        for dim in self.kernel_size:
            fan_in *= dim
        bound = 1. / math.sqrt(fan_in)
        self.weight.data.uniform_(-bound, bound)

    def forward(self, x, offset):
        return deform_conv(x, offset, self.weight, self.stride, self.padding,
                           self.dilation, self.groups, self.deformable_groups)
class DeformConvPack(DeformConv):
    """DeformConv with an internal offset-predicting conv (old copy).

    Offset branch outputs ``deformable_groups * 2 * kh * kw`` channels
    and is zero-initialized so the layer starts as a plain convolution.
    """

    def __init__(self, *args, **kwargs):
        super(DeformConvPack, self).__init__(*args, **kwargs)
        offset_channels = (self.deformable_groups * 2 *
                           self.kernel_size[0] * self.kernel_size[1])
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            offset_channels,
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            bias=True)
        self.init_offset()

    def init_offset(self):
        # Zero offsets at start-up -> regular convolution behavior.
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        offset = self.conv_offset(x)
        return deform_conv(x, offset, self.weight, self.stride, self.padding,
                           self.dilation, self.groups, self.deformable_groups)
class ModulatedDeformConv(nn.Module):
    """Modulated deformable convolution (DCNv2); pre-restructure copy.

    Stride/padding/dilation are stored as given (not normalized to
    pairs); an optional bias is supported.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=True):
        super(ModulatedDeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform weight init scaled by fan-in; bias starts at zero."""
        fan_in = self.in_channels
        for dim in self.kernel_size:
            fan_in *= dim
        bound = 1. / math.sqrt(fan_in)
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x, offset, mask):
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
    """ModulatedDeformConv with internal offset/mask branch (old copy).

    One conv emits ``deformable_groups * 3 * kh * kw`` channels: offsets
    (two thirds) plus a sigmoid-squashed modulation mask (last third).
    """

    def __init__(self, *args, **kwargs):
        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
        branch_channels = (self.deformable_groups * 3 *
                           self.kernel_size[0] * self.kernel_size[1])
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            branch_channels,
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            bias=True)
        self.init_offset()

    def init_offset(self):
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, x):
        out = self.conv_offset_mask(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)
mmdet/ops/gcb/__init__.py
deleted
100644 → 0
View file @
8387aba8
from
.context_block
import
ContextBlock
__all__
=
[
'ContextBlock'
,
]
mmdet/ops/masked_conv/__init__.py
View file @
86cc430a
from
.functions.masked_conv
import
masked_conv2d
from
.modules.masked_conv
import
MaskedConv2d
from
.masked_conv
import
masked_conv2d
,
MaskedConv2d
__all__
=
[
'masked_conv2d'
,
'MaskedConv2d'
]
mmdet/ops/masked_conv/functions/__init__.py
deleted
100644 → 0
View file @
8387aba8
mmdet/ops/masked_conv/
functions/
masked_conv.py
→
mmdet/ops/masked_conv/masked_conv.py
View file @
86cc430a
import
math
import
torch
import
torch.nn
as
nn
from
torch.autograd
import
Function
from
torch.autograd.function
import
once_differentiable
from
torch.nn.modules.utils
import
_pair
from
..
import
masked_conv2d_cuda
from
.
import
masked_conv2d_cuda
class
MaskedConv2dFunction
(
Function
):
...
...
@@ -49,8 +53,37 @@ class MaskedConv2dFunction(Function):
return
output
@
staticmethod
@
once_differentiable
def
backward
(
ctx
,
grad_output
):
return
(
None
,
)
*
5
masked_conv2d
=
MaskedConv2dFunction
.
apply
class MaskedConv2d(nn.Conv2d):
    """A Conv2d variant that can restrict computation to masked positions.

    With ``mask=None`` it behaves exactly like ``nn.Conv2d``. The masked
    path has no backward implementation and only supports stride 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super(MaskedConv2d,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, groups, bias)

    def forward(self, input, mask=None):
        if mask is None:
            # No mask given: fall back to the stock Conv2d forward.
            return super(MaskedConv2d, self).forward(input)
        else:
            return masked_conv2d(input, mask, self.weight, self.bias,
                                 self.padding)
mmdet/ops/masked_conv/modules/__init__.py
deleted
100644 → 0
View file @
8387aba8
mmdet/ops/masked_conv/modules/masked_conv.py
deleted
100644 → 0
View file @
8387aba8
import
torch.nn
as
nn
from
..functions.masked_conv
import
masked_conv2d
class MaskedConv2d(nn.Conv2d):
    """Conv2d with an optional position mask (pre-restructure copy).

    Without a mask it is an ordinary ``nn.Conv2d``; the masked path has
    no backward and only supports stride 1.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super(MaskedConv2d,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, groups, bias)

    def forward(self, input, mask=None):
        if mask is None:
            # No mask given: fall back to the stock Conv2d forward.
            return super(MaskedConv2d, self).forward(input)
        else:
            return masked_conv2d(input, mask, self.weight, self.bias,
                                 self.padding)
mmdet/ops/roi_align/__init__.py
View file @
86cc430a
from
.functions.roi_align
import
roi_align
from
.modules.roi_align
import
RoIAlign
from
.roi_align
import
roi_align
,
RoIAlign
__all__
=
[
'roi_align'
,
'RoIAlign'
]
mmdet/ops/roi_align/functions/__init__.py
deleted
100644 → 0
View file @
8387aba8
mmdet/ops/roi_align/modules/__init__.py
deleted
100644 → 0
View file @
8387aba8
mmdet/ops/roi_align/modules/roi_align.py
deleted
100644 → 0
View file @
8387aba8
import
torch.nn
as
nn
from
torch.nn.modules.utils
import
_pair
from
..functions.roi_align
import
roi_align
class RoIAlign(nn.Module):
    """RoI Align pooling layer (pre-restructure copy, without __repr__).

    Dispatches either to torchvision's implementation or to the local
    ``roi_align`` op depending on ``use_torchvision``.
    """

    def __init__(self,
                 out_size,
                 spatial_scale,
                 sample_num=0,
                 use_torchvision=False):
        super(RoIAlign, self).__init__()
        self.out_size = out_size
        self.spatial_scale = float(spatial_scale)
        self.sample_num = int(sample_num)
        self.use_torchvision = use_torchvision

    def forward(self, features, rois):
        if self.use_torchvision:
            # Imported lazily so torchvision stays an optional dependency.
            from torchvision.ops import roi_align as tv_roi_align
            return tv_roi_align(features, rois, _pair(self.out_size),
                                self.spatial_scale, self.sample_num)
        else:
            return roi_align(features, rois, self.out_size,
                             self.spatial_scale, self.sample_num)
mmdet/ops/roi_align/
functions/
roi_align.py
→
mmdet/ops/roi_align/roi_align.py
View file @
86cc430a
import
torch.nn
as
nn
from
torch.autograd
import
Function
from
torch.autograd.function
import
once_differentiable
from
torch.nn.modules.utils
import
_pair
from
.
.
import
roi_align_cuda
from
.
import
roi_align_cuda
class
RoIAlignFunction
(
Function
):
...
...
@@ -28,6 +30,7 @@ class RoIAlignFunction(Function):
return
output
@
staticmethod
@
once_differentiable
def
backward
(
ctx
,
grad_output
):
feature_size
=
ctx
.
feature_size
spatial_scale
=
ctx
.
spatial_scale
...
...
@@ -51,3 +54,34 @@ class RoIAlignFunction(Function):
roi_align
=
RoIAlignFunction
.
apply
class RoIAlign(nn.Module):
    """RoI Align pooling layer.

    Dispatches either to torchvision's implementation or to the local
    ``roi_align`` op depending on ``use_torchvision``.
    """

    def __init__(self,
                 out_size,
                 spatial_scale,
                 sample_num=0,
                 use_torchvision=False):
        super(RoIAlign, self).__init__()
        self.out_size = out_size
        self.spatial_scale = float(spatial_scale)
        self.sample_num = int(sample_num)
        self.use_torchvision = use_torchvision

    def forward(self, features, rois):
        if self.use_torchvision:
            # Imported lazily so torchvision stays an optional dependency.
            from torchvision.ops import roi_align as tv_roi_align
            return tv_roi_align(features, rois, _pair(self.out_size),
                                self.spatial_scale, self.sample_num)
        else:
            return roi_align(features, rois, self.out_size,
                             self.spatial_scale, self.sample_num)

    def __repr__(self):
        format_str = self.__class__.__name__
        format_str += '(out_size={}, spatial_scale={}, sample_num={}'.format(
            self.out_size, self.spatial_scale, self.sample_num)
        format_str += ', use_torchvision={})'.format(self.use_torchvision)
        return format_str
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment