ModelZoo / SOLOv2-pytorch
"vscode:/vscode.git/clone" did not exist on "a1eea964b6d7a19ab50aa6418094ae74f8cc83bb"
Commit 6bd60eac, authored Jan 15, 2019 by yhcao6
refactor dcn python interface
parent ef86c404

Showing 6 changed files with 216 additions and 183 deletions (+216 -183).
compile.sh                                        +13   -0
mmdet/ops/dcn/functions/deform_conv.py            +60   -69
mmdet/ops/dcn/functions/modulated_dcn_func.py     +109  -84
mmdet/ops/dcn/modules/deform_conv.py              +6    -5
mmdet/ops/dcn/modules/modulated_dcn.py            +26   -23
mmdet/ops/dcn/setup_modulated.py                  +2    -2
compile.sh

@@ -20,3 +20,16 @@ echo "Building nms op..."
 cd ../nms
 make clean
 make PYTHON=${PYTHON}
+
+echo "Building nms op..."
+cd ../nms
+make clean
+make PYTHON=${PYTHON}
+
+echo "Building dcn..."
+cd ../dcn
+if [ -d "build" ]; then
+    rm -r build
+fi
+$PYTHON setup.py build_ext --inplace
+$PYTHON setup_modulated.py build_ext --inplace
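
Note: the added lines build the CUDA extensions that the refactored Python code imports, with setup_modulated.py now producing modulated_dcn_cuda and the sibling setup.py presumably producing deform_conv_cuda (the latter file is not part of this diff, so that name is an inference from the new import). A hypothetical post-build smoke test, not part of this commit and assuming mmdet is importable with both extensions built in place, is simply to import them:

# hypothetical smoke test, not from this commit
from mmdet.ops.dcn import deform_conv_cuda        # assumed to be built by setup.py
from mmdet.ops.dcn import modulated_dcn_cuda      # built by setup_modulated.py

print(deform_conv_cuda.__file__, modulated_dcn_cuda.__file__)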
mmdet/ops/dcn/functions/deform_conv.py

@@ -2,49 +2,37 @@ import torch
 from torch.autograd import Function
 from torch.nn.modules.utils import _pair

-from .. import deform_conv
+from .. import deform_conv_cuda
-
-
-def deform_conv_function(input,
-                         offset,
-                         weight,
-                         stride=1,
-                         padding=0,
-                         dilation=1,
-                         deform_groups=1,
-                         im2col_step=64):
-
-    if input is not None and input.dim() != 4:
-        raise ValueError(
-            "Expected 4D tensor as input, got {}D tensor instead.".format(
-                input.dim()))
-
-    f = DeformConvFunction(
-        _pair(stride), _pair(padding), _pair(dilation), deform_groups,
-        im2col_step)
-    return f(input, offset, weight)


 class DeformConvFunction(Function):

-    def __init__(self,
-                 stride,
-                 padding,
-                 dilation,
-                 deformable_groups=1,
-                 im2col_step=64):
-        super(DeformConvFunction, self).__init__()
-        self.stride = stride
-        self.padding = padding
-        self.dilation = dilation
-        self.deformable_groups = deformable_groups
-        self.im2col_step = im2col_step
-
-    def forward(self, input, offset, weight):
-        self.save_for_backward(input, offset, weight)
-
-        output = input.new(*self._output_size(input, weight))
-
-        self.bufs_ = [input.new(), input.new()]  # columns, ones
+    @staticmethod
+    def forward(ctx,
+                input,
+                offset,
+                weight,
+                stride=1,
+                padding=0,
+                dilation=1,
+                deformable_groups=1,
+                im2col_step=64):
+        if input is not None and input.dim() != 4:
+            raise ValueError(
+                "Expected 4D tensor as input, got {}D tensor instead.".format(
+                    input.dim()))
+        ctx.stride = _pair(stride)
+        ctx.padding = _pair(padding)
+        ctx.dilation = _pair(dilation)
+        ctx.deformable_groups = deformable_groups
+        ctx.im2col_step = im2col_step
+
+        ctx.save_for_backward(input, offset, weight)
+
+        output = input.new(*DeformConvFunction._output_size(
+            input, weight, ctx.padding, ctx.dilation, ctx.stride))
+
+        ctx.bufs_ = [input.new(), input.new()]  # columns, ones
         if not input.is_cuda:
             raise NotImplementedError

@@ -56,18 +44,19 @@ class DeformConvFunction(Function):
         if not isinstance(input, torch.cuda.FloatTensor):
             raise NotImplementedError

-        cur_im2col_step = min(self.im2col_step, input.shape[0])
+        cur_im2col_step = min(ctx.im2col_step, input.shape[0])
         assert (input.shape[0] %
                 cur_im2col_step) == 0, 'im2col step must divide batchsize'
-        deform_conv.deform_conv_forward_cuda(
-            input, weight, offset, output, self.bufs_[0], self.bufs_[1],
-            weight.size(3), weight.size(2), self.stride[1], self.stride[0],
-            self.padding[1], self.padding[0], self.dilation[1],
-            self.dilation[0], self.deformable_groups, cur_im2col_step)
+        deform_conv_cuda.deform_conv_forward_cuda(
+            input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
+            weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
+            ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+            ctx.dilation[0], ctx.deformable_groups, cur_im2col_step)
         return output

-    def backward(self, grad_output):
-        input, offset, weight = self.saved_tensors
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, offset, weight = ctx.saved_tensors

         grad_input = grad_offset = grad_weight = None

@@ -81,44 +70,46 @@ class DeformConvFunction(Function):
         if not isinstance(grad_output, torch.cuda.FloatTensor):
             raise NotImplementedError

-        cur_im2col_step = min(self.im2col_step, input.shape[0])
+        cur_im2col_step = min(ctx.im2col_step, input.shape[0])
         assert (input.shape[0] %
                 cur_im2col_step) == 0, 'im2col step must divide batchsize'

-        if self.needs_input_grad[0] or self.needs_input_grad[1]:
-            grad_input = input.new(*input.size()).zero_()
-            grad_offset = offset.new(*offset.size()).zero_()
-            deform_conv.deform_conv_backward_input_cuda(
+        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+            grad_input = torch.zeros_like(input)
+            grad_offset = torch.zeros_like(offset)
+            deform_conv_cuda.deform_conv_backward_input_cuda(
                 input, offset, grad_output, grad_input,
-                grad_offset, weight, self.bufs_[0], weight.size(3),
-                weight.size(2), self.stride[1], self.stride[0],
-                self.padding[1], self.padding[0], self.dilation[1],
-                self.dilation[0], self.deformable_groups, cur_im2col_step)
+                grad_offset, weight, ctx.bufs_[0], weight.size(3),
+                weight.size(2), ctx.stride[1], ctx.stride[0],
+                ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+                ctx.dilation[0], ctx.deformable_groups, cur_im2col_step)

-        if self.needs_input_grad[2]:
-            grad_weight = weight.new(*weight.size()).zero_()
-            deform_conv.deform_conv_backward_parameters_cuda(
+        if ctx.needs_input_grad[2]:
+            grad_weight = torch.zeros_like(weight)
+            deform_conv_cuda.deform_conv_backward_parameters_cuda(
                 input, offset, grad_output,
-                grad_weight, self.bufs_[0], self.bufs_[1], weight.size(3),
-                weight.size(2), self.stride[1], self.stride[0],
-                self.padding[1], self.padding[0], self.dilation[1],
-                self.dilation[0], self.deformable_groups, 1,
+                grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+                weight.size(2), ctx.stride[1], ctx.stride[0],
+                ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+                ctx.dilation[0], ctx.deformable_groups, 1,
                 cur_im2col_step)

-        return grad_input, grad_offset, grad_weight
+        return grad_input, grad_offset, grad_weight, None, None, None, None

-    def _output_size(self, input, weight):
+    @staticmethod
+    def _output_size(input, weight, padding, dilation, stride):
         channels = weight.size(0)

         output_size = (input.size(0), channels)
         for d in range(input.dim() - 2):
             in_size = input.size(d + 2)
-            pad = self.padding[d]
-            kernel = self.dilation[d] * (weight.size(d + 2) - 1) + 1
-            stride = self.stride[d]
-            output_size += ((in_size + (2 * pad) - kernel) // stride + 1, )
+            pad = padding[d]
+            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+            stride_ = stride[d]
+            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
         if not all(map(lambda s: s > 0, output_size)):
             raise ValueError(
                 "convolution input is too small (output would be {})".format(
                     'x'.join(map(str, output_size))))
         return output_size
+
+
+deform_conv = DeformConvFunction.apply
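
With the module-level alias in place, callers invoke deform_conv directly instead of instantiating DeformConvFunction, and the hyper-parameters travel as extra arguments of forward. A minimal calling sketch follows (illustrative shapes, not code from this commit; it assumes a CUDA device and the compiled deform_conv_cuda extension). Function.apply accepts positional arguments only, so stride, padding, dilation, deformable_groups and im2col_step are passed positionally:

# hedged usage sketch, not part of this commit
import torch
from mmdet.ops.dcn.functions.deform_conv import deform_conv  # DeformConvFunction.apply

x = torch.randn(2, 16, 32, 32, device='cuda')
# offset carries 2 * deformable_groups * kh * kw channels (a dy, dx pair per sampling point)
offset = torch.randn(2, 18, 32, 32, device='cuda')
weight = torch.randn(16, 16, 3, 3, device='cuda', requires_grad=True)

# argument order: input, offset, weight, stride, padding, dilation, deformable_groups, im2col_step
out = deform_conv(x, offset, weight, 1, 1, 1, 1, 64)
out.sum().backward()  # grad reaches weight; the non-tensor args map to the trailing Nones in backward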
mmdet/ops/dcn/functions/modulated_dcn_func.py

@@ -6,128 +6,153 @@ from __future__ import print_function
 import torch
 from torch.autograd import Function

-from .. import modulated_dcn as _backend
+from .. import modulated_dcn_cuda as _backend


 class ModulatedDeformConvFunction(Function):

-    def __init__(self, stride, padding, dilation=1, deformable_groups=1):
-        super(ModulatedDeformConvFunction, self).__init__()
-        self.stride = stride
-        self.padding = padding
-        self.dilation = dilation
-        self.deformable_groups = deformable_groups
+    def __init__(ctx, stride, padding, dilation=1, deformable_groups=1):
+        super(ModulatedDeformConvFunction, ctx).__init__()
+        ctx.stride = stride
+        ctx.padding = padding
+        ctx.dilation = dilation
+        ctx.deformable_groups = deformable_groups

-    def forward(self, input, offset, mask, weight, bias):
+    @staticmethod
+    def forward(ctx,
+                input,
+                offset,
+                mask,
+                weight,
+                bias,
+                stride,
+                padding,
+                dilation=1,
+                deformable_groups=1):
+        ctx.stride = stride
+        ctx.padding = padding
+        ctx.dilation = dilation
+        ctx.deformable_groups = deformable_groups
         if not input.is_cuda:
             raise NotImplementedError
         if weight.requires_grad or mask.requires_grad or offset.requires_grad \
                 or input.requires_grad:
-            self.save_for_backward(input, offset, mask, weight, bias)
-        output = input.new(*self._infer_shape(input, weight))
-        self._bufs = [input.new(), input.new()]
+            ctx.save_for_backward(input, offset, mask, weight, bias)
+        output = input.new(
+            *ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+        ctx._bufs = [input.new(), input.new()]
         _backend.modulated_deform_conv_cuda_forward(
-            input, weight, bias, self._bufs[0], offset, mask, output,
-            self._bufs[1], weight.shape[2], weight.shape[3], self.stride,
-            self.stride, self.padding, self.padding, self.dilation,
-            self.dilation, self.deformable_groups)
+            input, weight, bias, ctx._bufs[0], offset, mask, output,
+            ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
+            ctx.stride, ctx.padding, ctx.padding, ctx.dilation,
+            ctx.dilation, ctx.deformable_groups)
         return output

-    def backward(self, grad_output):
+    @staticmethod
+    def backward(ctx, grad_output):
         if not grad_output.is_cuda:
             raise NotImplementedError
-        input, offset, mask, weight, bias = self.saved_tensors
-        grad_input = input.new(*input.size()).zero_()
-        grad_offset = offset.new(*offset.size()).zero_()
-        grad_mask = mask.new(*mask.size()).zero_()
-        grad_weight = weight.new(*weight.size()).zero_()
-        grad_bias = bias.new(*bias.size()).zero_()
+        input, offset, mask, weight, bias = ctx.saved_tensors
+        grad_input = torch.zeros_like(input)
+        grad_offset = torch.zeros_like(offset)
+        grad_mask = torch.zeros_like(mask)
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(bias)
         _backend.modulated_deform_conv_cuda_backward(
-            input, weight, bias, self._bufs[0], offset, mask, self._bufs[1],
+            input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
             grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
-            grad_output, weight.shape[2], weight.shape[3], self.stride,
-            self.stride, self.padding, self.padding, self.dilation,
-            self.dilation, self.deformable_groups)
+            grad_output, weight.shape[2], weight.shape[3], ctx.stride,
+            ctx.stride, ctx.padding, ctx.padding, ctx.dilation,
+            ctx.dilation, ctx.deformable_groups)

-        return grad_input, grad_offset, grad_mask, grad_weight, grad_bias
+        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
+                None, None, None, None)

-    def _infer_shape(self, input, weight):
+    @staticmethod
+    def _infer_shape(ctx, input, weight):
         n = input.size(0)
         channels_out = weight.size(0)
         height, width = input.shape[2:4]
         kernel_h, kernel_w = weight.shape[2:4]
-        height_out = (height + 2 * self.padding -
-                      (self.dilation * (kernel_h - 1) + 1)) // self.stride + 1
-        width_out = (width + 2 * self.padding -
-                     (self.dilation * (kernel_w - 1) + 1)) // self.stride + 1
-        return (n, channels_out, height_out, width_out)
+        height_out = (height + 2 * ctx.padding -
+                      (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
+        width_out = (width + 2 * ctx.padding -
+                     (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
+        return n, channels_out, height_out, width_out


 class DeformRoIPoolingFunction(Function):

-    def __init__(self,
-                 spatial_scale,
-                 pooled_size,
-                 output_dim,
-                 no_trans,
-                 group_size=1,
-                 part_size=None,
-                 sample_per_part=4,
-                 trans_std=.0):
-        super(DeformRoIPoolingFunction, self).__init__()
-        self.spatial_scale = spatial_scale
-        self.pooled_size = pooled_size
-        self.output_dim = output_dim
-        self.no_trans = no_trans
-        self.group_size = group_size
-        self.part_size = pooled_size if part_size is None else part_size
-        self.sample_per_part = sample_per_part
-        self.trans_std = trans_std
-
-        assert self.trans_std >= 0.0 and self.trans_std <= 1.0
-
-    def forward(self, data, rois, offset):
+    @staticmethod
+    def forward(ctx,
+                data,
+                rois,
+                offset,
+                spatial_scale,
+                pooled_size,
+                output_dim,
+                no_trans,
+                group_size=1,
+                part_size=None,
+                sample_per_part=4,
+                trans_std=.0):
+        ctx.spatial_scale = spatial_scale
+        ctx.pooled_size = pooled_size
+        ctx.output_dim = output_dim
+        ctx.no_trans = no_trans
+        ctx.group_size = group_size
+        ctx.part_size = pooled_size if part_size is None else part_size
+        ctx.sample_per_part = sample_per_part
+        ctx.trans_std = trans_std
+
+        assert 0.0 <= ctx.trans_std <= 1.0
         if not data.is_cuda:
             raise NotImplementedError

-        output = data.new(*self._infer_shape(data, rois))
-        output_count = data.new(*self._infer_shape(data, rois))
+        output = data.new(
+            *DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
+        output_count = data.new(
+            *DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
         _backend.deform_psroi_pooling_cuda_forward(
-            data, rois, offset, output, output_count, self.no_trans,
-            self.spatial_scale, self.output_dim, self.group_size,
-            self.pooled_size, self.part_size, self.sample_per_part,
-            self.trans_std)
+            data, rois, offset, output, output_count, ctx.no_trans,
+            ctx.spatial_scale, ctx.output_dim, ctx.group_size,
+            ctx.pooled_size, ctx.part_size, ctx.sample_per_part,
+            ctx.trans_std)

         # if data.requires_grad or rois.requires_grad or offset.requires_grad:
-        #     self.save_for_backward(data, rois, offset, output_count)
-        self.data = data
-        self.rois = rois
-        self.offset = offset
-        self.output_count = output_count
+        #     ctx.save_for_backward(data, rois, offset, output_count)
+        ctx.data = data
+        ctx.rois = rois
+        ctx.offset = offset
+        ctx.output_count = output_count

         return output

-    def backward(self, grad_output):
+    @staticmethod
+    def backward(ctx, grad_output):
         if not grad_output.is_cuda:
             raise NotImplementedError

-        # data, rois, offset, output_count = self.saved_tensors
-        data = self.data
-        rois = self.rois
-        offset = self.offset
-        output_count = self.output_count
-        grad_input = data.new(*data.size()).zero_()
-        grad_offset = offset.new(*offset.size()).zero_()
+        # data, rois, offset, output_count = ctx.saved_tensors
+        data = ctx.data
+        rois = ctx.rois
+        offset = ctx.offset
+        output_count = ctx.output_count
+        grad_input = torch.zeros_like(data)
+        grad_offset = torch.zeros_like(offset)
         _backend.deform_psroi_pooling_cuda_backward(
             grad_output, data, rois, offset, output_count, grad_input,
-            grad_offset, self.no_trans, self.spatial_scale, self.output_dim,
-            self.group_size, self.pooled_size, self.part_size,
-            self.sample_per_part, self.trans_std)
-        return grad_input, torch.zeros(rois.shape).cuda(), grad_offset
+            grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.output_dim,
+            ctx.group_size, ctx.pooled_size, ctx.part_size,
+            ctx.sample_per_part, ctx.trans_std)
+        return (grad_input, torch.zeros_like(rois), grad_offset, None, None,
+                None, None, None, None, None, None)

-    def _infer_shape(self, data, rois):
+    @staticmethod
+    def _infer_shape(ctx, data, rois):
         # _, c, h, w = data.shape[:4]
         # c = data.shape[1]
         n = rois.shape[0]
-        return n, self.output_dim, self.pooled_size, self.pooled_size
+        return n, ctx.output_dim, ctx.pooled_size, ctx.pooled_size
+
+
+modulated_deform_conv = ModulatedDeformConvFunction.apply
+deform_roi_pooling = DeformRoIPoolingFunction.apply
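
Both Functions in this file follow the same conversion: hyper-parameters move from __init__ into extra forward() arguments stored on ctx, forward/backward become @staticmethod, backward returns one extra None per non-tensor input, and module-level aliases bound to .apply become the public entry points. A self-contained toy Function showing the post-refactor shape of that pattern (illustration only, runs on CPU, not code from this repository):

import torch
from torch.autograd import Function


class ScaleShift(Function):
    """Toy example of the new-style Function used by this refactor."""

    @staticmethod
    def forward(ctx, x, scale, shift=0.0):
        ctx.scale = scale              # hyper-parameters live on ctx, not on an instance
        return x * scale + shift

    @staticmethod
    def backward(ctx, grad_out):
        # one return slot per forward() argument: a gradient for x,
        # None for the non-tensor hyper-parameters scale and shift
        return grad_out * ctx.scale, None, None


scale_shift = ScaleShift.apply         # mirrors modulated_deform_conv / deform_roi_pooling

y = scale_shift(torch.randn(4, requires_grad=True), 2.0, 1.0)
y.sum().backward()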
mmdet/ops/dcn/modules/deform_conv.py

@@ -2,10 +2,11 @@ import math
 import torch
 import torch.nn as nn
+from mmcv.cnn import uniform_init
 from torch.nn.modules.module import Module
 from torch.nn.modules.utils import _pair

-from ..functions.deform_conv import deform_conv_function
+from ..functions.deform_conv import deform_conv


 class DeformConv(Module):

@@ -37,9 +38,9 @@ class DeformConv(Module):
         for k in self.kernel_size:
             n *= k
         stdv = 1. / math.sqrt(n)
-        self.weight.data.uniform_(-stdv, stdv)
+        uniform_init(self, -stdv, stdv)

     def forward(self, input, offset):
-        return deform_conv_function(input, offset, self.weight, self.stride,
-                                    self.padding, self.dilation,
-                                    self.num_deformable_groups)
+        return deform_conv(input, offset, self.weight, self.stride,
+                           self.padding, self.dilation,
+                           self.num_deformable_groups)
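
Both module files switch the weight initialisation to mmcv.cnn.uniform_init. That helper roughly amounts to filling module.weight uniformly in [a, b] and resetting module.bias to a constant, which is also why the explicit self.bias.data.zero_() call can be dropped from modulated_dcn.py below. A paraphrased sketch of that behaviour (not mmcv's actual source):

import torch.nn as nn


def uniform_init_sketch(module, a=0., b=1., bias=0.):
    # paraphrase of mmcv.cnn.uniform_init: uniform weight, constant bias
    nn.init.uniform_(module.weight, a, b)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)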
mmdet/ops/dcn/modules/modulated_dcn.py

@@ -6,11 +6,12 @@ from __future__ import print_function
 import math

 import torch
+from mmcv.cnn import uniform_init
 from torch import nn
 from torch.nn.modules.utils import _pair

-from ..functions.modulated_dcn_func import DeformRoIPoolingFunction
-from ..functions.modulated_dcn_func import ModulatedDeformConvFunction
+from ..functions.modulated_dcn_func import deform_roi_pooling
+from ..functions.modulated_dcn_func import modulated_deform_conv


 class ModulatedDeformConv(nn.Module):

@@ -46,13 +47,12 @@ class ModulatedDeformConv(nn.Module):
         for k in self.kernel_size:
             n *= k
         stdv = 1. / math.sqrt(n)
-        self.weight.data.uniform_(-stdv, stdv)
-        self.bias.data.zero_()
+        uniform_init(self, -stdv, stdv)

     def forward(self, input, offset, mask):
-        func = ModulatedDeformConvFunction(
-            self.stride, self.padding, self.dilation, self.deformable_groups)
-        return func(input, offset, mask, self.weight, self.bias)
+        return modulated_deform_conv(input, offset, mask, self.weight,
+                                     self.bias, self.stride, self.padding,
+                                     self.dilation, self.deformable_groups)


 class ModulatedDeformConvPack(ModulatedDeformConv):

@@ -89,9 +89,9 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
         o1, o2, mask = torch.chunk(out, 3, dim=1)
         offset = torch.cat((o1, o2), dim=1)
         mask = torch.sigmoid(mask)
-        func = ModulatedDeformConvFunction(
-            self.stride, self.padding, self.dilation, self.deformable_groups)
-        return func(input, offset, mask, self.weight, self.bias)
+        return modulated_deform_conv(input, offset, mask, self.weight,
+                                     self.bias, self.stride, self.padding,
+                                     self.dilation, self.deformable_groups)


 class DeformRoIPooling(nn.Module):

@@ -115,16 +115,14 @@ class DeformRoIPooling(nn.Module):
         self.part_size = pooled_size if part_size is None else part_size
         self.sample_per_part = sample_per_part
         self.trans_std = trans_std
-        self.func = DeformRoIPoolingFunction(
-            self.spatial_scale, self.pooled_size, self.output_dim,
-            self.no_trans, self.group_size, self.part_size,
-            self.sample_per_part, self.trans_std)

     def forward(self, data, rois, offset):
         if self.no_trans:
             offset = data.new()
-        return self.func(data, rois, offset)
+        return deform_roi_pooling(
+            data, rois, offset, self.spatial_scale, self.pooled_size,
+            self.output_dim, self.no_trans, self.group_size, self.part_size,
+            self.sample_per_part, self.trans_std)


 class ModulatedDeformRoIPoolingPack(DeformRoIPooling):

@@ -146,10 +144,6 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
         self.deform_fc_dim = deform_fc_dim

         if not no_trans:
-            self.func_offset = DeformRoIPoolingFunction(
-                self.spatial_scale, self.pooled_size, self.output_dim, True,
-                self.group_size, self.part_size, self.sample_per_part,
-                self.trans_std)
             self.offset_fc = nn.Sequential(
                 nn.Linear(
                     self.pooled_size * self.pooled_size * self.output_dim,

@@ -176,11 +170,20 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
         else:
             n = rois.shape[0]
             offset = data.new()
-            x = self.func_offset(data, rois, offset)
+            x = deform_roi_pooling(
+                data, rois, offset, self.spatial_scale, self.pooled_size,
+                self.output_dim, True, self.group_size, self.part_size,
+                self.sample_per_part, self.trans_std)
             offset = self.offset_fc(x.view(n, -1))
             offset = offset.view(n, 2, self.pooled_size, self.pooled_size)
             mask = self.mask_fc(x.view(n, -1))
             mask = mask.view(n, 1, self.pooled_size, self.pooled_size)
-            feat = self.func(data, rois, offset) * mask
+            feat = deform_roi_pooling(
+                data, rois, offset, self.spatial_scale, self.pooled_size,
+                self.output_dim, self.no_trans, self.group_size,
+                self.part_size, self.sample_per_part, self.trans_std) * mask
             return feat
-        return self.func(data, rois, offset)
+        return deform_roi_pooling(
+            data, rois, offset, self.spatial_scale, self.pooled_size,
+            self.output_dim, self.no_trans, self.group_size, self.part_size,
+            self.sample_per_part, self.trans_std)
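
Since Function.apply does not accept keyword arguments, the pooling hyper-parameters are now repeated positionally at every call site, as in the forward methods above. A hypothetical convenience wrapper (not part of the commit) that simply spells out that argument order as keywords, with names and defaults read off DeformRoIPoolingFunction.forward earlier in this diff:

# hypothetical helper, not in the repository
from mmdet.ops.dcn.functions.modulated_dcn_func import deform_roi_pooling


def deform_roi_pooling_kw(data, rois, offset, *, spatial_scale, pooled_size,
                          output_dim, no_trans, group_size=1, part_size=None,
                          sample_per_part=4, trans_std=.0):
    # forward the keyword arguments positionally, in the order the Function expects
    return deform_roi_pooling(data, rois, offset, spatial_scale, pooled_size,
                              output_dim, no_trans, group_size, part_size,
                              sample_per_part, trans_std)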
mmdet/ops/dcn/setup_modulated.py

@@ -2,9 +2,9 @@ from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension

 setup(
-    name='modulated_deform_conv',
+    name='modulated_dcn_cuda',
     ext_modules=[
-        CUDAExtension('modulated_dcn', [
+        CUDAExtension('modulated_dcn_cuda', [
             'src/modulated_dcn_cuda.cpp',
             'src/modulated_deform_im2col_cuda.cu',
             'src/deform_psroi_pooling_cuda.cu'