Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
dc57735f
Commit
dc57735f
authored
Jan 21, 2019
by
yhcao6
Browse files
move bias check from module to function
parent
b1ba5939
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
37 additions
and
46 deletions
+37
-46
mmdet/ops/dcn/functions/deform_conv.py
mmdet/ops/dcn/functions/deform_conv.py
+10
-7
mmdet/ops/dcn/functions/deform_pool.py
mmdet/ops/dcn/functions/deform_pool.py
+0
-6
mmdet/ops/dcn/modules/deform_conv.py
mmdet/ops/dcn/modules/deform_conv.py
+6
-8
mmdet/ops/dcn/modules/deform_pool.py
mmdet/ops/dcn/modules/deform_pool.py
+21
-25
No files found.
mmdet/ops/dcn/functions/deform_conv.py
View file @
dc57735f
...
...
@@ -107,17 +107,18 @@ class ModulatedDeformConvFunction(Function):
offset
,
mask
,
weight
,
bias
,
stride
,
padding
,
bias
=
None
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
deformable_groups
=
1
,
with_bias
=
False
):
deformable_groups
=
1
):
ctx
.
stride
=
stride
ctx
.
padding
=
padding
ctx
.
dilation
=
dilation
ctx
.
deformable_groups
=
deformable_groups
ctx
.
with_bias
=
with_bias
ctx
.
with_bias
=
bias
is
not
None
if
not
ctx
.
with_bias
:
bias
=
input
.
new_empty
(
1
)
# fake tensor
if
not
input
.
is_cuda
:
raise
NotImplementedError
if
weight
.
requires_grad
or
mask
.
requires_grad
or
offset
.
requires_grad
\
...
...
@@ -149,9 +150,11 @@ class ModulatedDeformConvFunction(Function):
grad_output
,
weight
.
shape
[
2
],
weight
.
shape
[
3
],
ctx
.
stride
,
ctx
.
stride
,
ctx
.
padding
,
ctx
.
padding
,
ctx
.
dilation
,
ctx
.
dilation
,
ctx
.
deformable_groups
,
ctx
.
with_bias
)
if
not
ctx
.
with_bias
:
grad_bias
=
None
return
(
grad_input
,
grad_offset
,
grad_mask
,
grad_weight
,
grad_bias
,
None
,
None
,
None
,
None
,
None
)
None
,
None
,
None
,
None
)
@
staticmethod
def
_infer_shape
(
ctx
,
input
,
weight
):
...
...
mmdet/ops/dcn/functions/deform_pool.py
View file @
dc57735f
...
...
@@ -43,9 +43,6 @@ class DeformRoIPoolingFunction(Function):
if
data
.
requires_grad
or
rois
.
requires_grad
or
offset
.
requires_grad
:
ctx
.
save_for_backward
(
data
,
rois
,
offset
)
# ctx.data = data
# ctx.rois = rois
# ctx.offset = offset
ctx
.
output_count
=
output_count
return
output
...
...
@@ -56,9 +53,6 @@ class DeformRoIPoolingFunction(Function):
raise
NotImplementedError
data
,
rois
,
offset
=
ctx
.
saved_tensors
# data = ctx.data
# rois = ctx.rois
# offset = ctx.offset
output_count
=
ctx
.
output_count
grad_input
=
torch
.
zeros_like
(
data
)
grad_offset
=
torch
.
zeros_like
(
offset
)
...
...
mmdet/ops/dcn/modules/deform_conv.py
View file @
dc57735f
...
...
@@ -71,7 +71,7 @@ class ModulatedDeformConv(nn.Module):
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
Tensor
(
out_channels
))
else
:
self
.
bias
=
nn
.
Parameter
(
torch
.
Tensor
([
0
]))
# fake tensor
self
.
register_parameter
(
'bias'
,
None
)
self
.
reset_parameters
()
def
reset_parameters
(
self
):
...
...
@@ -80,14 +80,13 @@ class ModulatedDeformConv(nn.Module):
n
*=
k
stdv
=
1.
/
math
.
sqrt
(
n
)
self
.
weight
.
data
.
uniform_
(
-
stdv
,
stdv
)
if
self
.
with_bias
:
if
self
.
bias
is
not
None
:
self
.
bias
.
data
.
zero_
()
def
forward
(
self
,
input
,
offset
,
mask
):
return
modulated_deform_conv
(
input
,
offset
,
mask
,
self
.
weight
,
self
.
bias
,
self
.
stride
,
self
.
padding
,
self
.
dilation
,
self
.
deformable_groups
,
self
.
with_bias
)
self
.
dilation
,
self
.
deformable_groups
)
class
ModulatedDeformConvPack
(
ModulatedDeformConv
):
...
...
@@ -110,8 +109,8 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
self
.
deformable_groups
*
3
*
self
.
kernel_size
[
0
]
*
self
.
kernel_size
[
1
],
kernel_size
=
self
.
kernel_size
,
stride
=
(
self
.
stride
,
self
.
stride
),
padding
=
(
self
.
padding
,
self
.
padding
),
stride
=
_pair
(
self
.
stride
),
padding
=
_pair
(
self
.
padding
),
bias
=
True
)
self
.
init_offset
()
...
...
@@ -126,5 +125,4 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
mask
=
torch
.
sigmoid
(
mask
)
return
modulated_deform_conv
(
input
,
offset
,
mask
,
self
.
weight
,
self
.
bias
,
self
.
stride
,
self
.
padding
,
self
.
dilation
,
self
.
deformable_groups
,
self
.
with_bias
)
self
.
dilation
,
self
.
deformable_groups
)
mmdet/ops/dcn/modules/deform_pool.py
View file @
dc57735f
...
...
@@ -33,7 +33,7 @@ class DeformRoIPooling(nn.Module):
self
.
sample_per_part
,
self
.
trans_std
)
class
Modulated
DeformRoIPoolingPack
(
DeformRoIPooling
):
class
DeformRoIPoolingPack
(
DeformRoIPooling
):
def
__init__
(
self
,
spatial_scale
,
...
...
@@ -45,32 +45,22 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
sample_per_part
=
4
,
trans_std
=
.
0
,
deform_fc_dim
=
1024
):
super
(
Modulated
DeformRoIPoolingPack
,
self
).
__init__
(
spatial_scale
,
out_size
,
output_dim
,
no_trans
,
group_size
,
part_size
,
sample_per_part
,
trans_std
)
super
(
DeformRoIPoolingPack
,
self
).
__init__
(
spatial_scale
,
out_size
,
output_dim
,
no_trans
,
group_size
,
part_size
,
sample_per_part
,
trans_std
)
self
.
deform_fc_dim
=
deform_fc_dim
if
not
no_trans
:
self
.
offset_fc
=
nn
.
Sequential
(
nn
.
Linear
(
self
.
out_size
*
self
.
out_size
*
self
.
output_dim
,
self
.
deform_fc_dim
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
Linear
(
self
.
out_size
*
self
.
out_size
*
self
.
output_dim
,
self
.
deform_fc_dim
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
Linear
(
self
.
deform_fc_dim
,
self
.
deform_fc_dim
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
Linear
(
self
.
deform_fc_dim
,
self
.
out_size
*
self
.
out_size
*
2
))
self
.
offset_fc
[
4
].
weight
.
data
.
zero_
()
self
.
offset_fc
[
4
].
bias
.
data
.
zero_
()
self
.
mask_fc
=
nn
.
Sequential
(
nn
.
Linear
(
self
.
out_size
*
self
.
out_size
*
self
.
output_dim
,
self
.
deform_fc_dim
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
Linear
(
self
.
deform_fc_dim
,
self
.
out_size
*
self
.
out_size
*
1
),
nn
.
Sigmoid
())
self
.
mask_fc
[
2
].
weight
.
data
.
zero_
()
self
.
mask_fc
[
2
].
bias
.
data
.
zero_
()
def
forward
(
self
,
data
,
rois
):
if
self
.
no_trans
:
...
...
@@ -84,12 +74,10 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
self
.
sample_per_part
,
self
.
trans_std
)
offset
=
self
.
offset_fc
(
x
.
view
(
n
,
-
1
))
offset
=
offset
.
view
(
n
,
2
,
self
.
out_size
,
self
.
out_size
)
mask
=
self
.
mask_fc
(
x
.
view
(
n
,
-
1
))
mask
=
mask
.
view
(
n
,
1
,
self
.
out_size
,
self
.
out_size
)
feat
=
deform_roi_pooling
(
data
,
rois
,
offset
,
self
.
spatial_scale
,
self
.
out_size
,
self
.
output_dim
,
self
.
no_trans
,
self
.
group_size
,
self
.
part_size
,
self
.
sample_per_part
,
self
.
trans_std
)
*
mask
self
.
part_size
,
self
.
sample_per_part
,
self
.
trans_std
)
return
feat
return
deform_roi_pooling
(
data
,
rois
,
offset
,
self
.
spatial_scale
,
self
.
out_size
,
...
...
@@ -97,7 +85,7 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
self
.
sample_per_part
,
self
.
trans_std
)
class
DeformRoIPoolingPack
(
DeformRoIPooling
):
class
Modulated
DeformRoIPoolingPack
(
DeformRoIPooling
):
def
__init__
(
self
,
spatial_scale
,
...
...
@@ -109,7 +97,7 @@ class DeformRoIPoolingPack(DeformRoIPooling):
sample_per_part
=
4
,
trans_std
=
.
0
,
deform_fc_dim
=
1024
):
super
(
DeformRoIPoolingPack
,
self
).
__init__
(
super
(
Modulated
DeformRoIPoolingPack
,
self
).
__init__
(
spatial_scale
,
out_size
,
output_dim
,
no_trans
,
group_size
,
part_size
,
sample_per_part
,
trans_std
)
...
...
@@ -117,15 +105,21 @@ class DeformRoIPoolingPack(DeformRoIPooling):
if
not
no_trans
:
self
.
offset_fc
=
nn
.
Sequential
(
nn
.
Linear
(
self
.
out_size
*
self
.
out_size
*
self
.
output_dim
,
self
.
deform_fc_dim
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
Linear
(
self
.
out_size
*
self
.
out_size
*
self
.
output_dim
,
self
.
deform_fc_dim
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
Linear
(
self
.
deform_fc_dim
,
self
.
deform_fc_dim
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
Linear
(
self
.
deform_fc_dim
,
self
.
out_size
*
self
.
out_size
*
2
))
self
.
offset_fc
[
4
].
weight
.
data
.
zero_
()
self
.
offset_fc
[
4
].
bias
.
data
.
zero_
()
self
.
mask_fc
=
nn
.
Sequential
(
nn
.
Linear
(
self
.
out_size
*
self
.
out_size
*
self
.
output_dim
,
self
.
deform_fc_dim
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
Linear
(
self
.
deform_fc_dim
,
self
.
out_size
*
self
.
out_size
*
1
),
nn
.
Sigmoid
())
self
.
mask_fc
[
2
].
weight
.
data
.
zero_
()
self
.
mask_fc
[
2
].
bias
.
data
.
zero_
()
def
forward
(
self
,
data
,
rois
):
if
self
.
no_trans
:
...
...
@@ -139,10 +133,12 @@ class DeformRoIPoolingPack(DeformRoIPooling):
self
.
sample_per_part
,
self
.
trans_std
)
offset
=
self
.
offset_fc
(
x
.
view
(
n
,
-
1
))
offset
=
offset
.
view
(
n
,
2
,
self
.
out_size
,
self
.
out_size
)
mask
=
self
.
mask_fc
(
x
.
view
(
n
,
-
1
))
mask
=
mask
.
view
(
n
,
1
,
self
.
out_size
,
self
.
out_size
)
feat
=
deform_roi_pooling
(
data
,
rois
,
offset
,
self
.
spatial_scale
,
self
.
out_size
,
self
.
output_dim
,
self
.
no_trans
,
self
.
group_size
,
self
.
part_size
,
self
.
sample_per_part
,
self
.
trans_std
)
self
.
part_size
,
self
.
sample_per_part
,
self
.
trans_std
)
*
mask
return
feat
return
deform_roi_pooling
(
data
,
rois
,
offset
,
self
.
spatial_scale
,
self
.
out_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment