Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
350fdd7a
Unverified
Commit
350fdd7a
authored
Jul 15, 2019
by
Kai Chen
Committed by
GitHub
Jul 15, 2019
Browse files
support torchvision RoIPool and RoIAlign (#990)
parent
c101398c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
32 deletions
+36
-32
mmdet/ops/roi_align/functions/roi_align.py
mmdet/ops/roi_align/functions/roi_align.py
+3
-11
mmdet/ops/roi_align/modules/roi_align.py
mmdet/ops/roi_align/modules/roi_align.py
+18
-6
mmdet/ops/roi_pool/functions/roi_pool.py
mmdet/ops/roi_pool/functions/roi_pool.py
+3
-11
mmdet/ops/roi_pool/modules/roi_pool.py
mmdet/ops/roi_pool/modules/roi_pool.py
+12
-4
No files found.
mmdet/ops/roi_align/functions/roi_align.py
View file @
350fdd7a
from
torch.autograd
import
Function
from
torch.nn.modules.utils
import
_pair
from
..
import
roi_align_cuda
...
...
@@ -7,17 +8,8 @@ class RoIAlignFunction(Function):
@
staticmethod
def
forward
(
ctx
,
features
,
rois
,
out_size
,
spatial_scale
,
sample_num
=
0
):
if
isinstance
(
out_size
,
int
):
out_h
=
out_size
out_w
=
out_size
elif
isinstance
(
out_size
,
tuple
):
assert
len
(
out_size
)
==
2
assert
isinstance
(
out_size
[
0
],
int
)
assert
isinstance
(
out_size
[
1
],
int
)
out_h
,
out_w
=
out_size
else
:
raise
TypeError
(
'"out_size" must be an integer or tuple of integers'
)
out_h
,
out_w
=
_pair
(
out_size
)
assert
isinstance
(
out_h
,
int
)
and
isinstance
(
out_w
,
int
)
ctx
.
spatial_scale
=
spatial_scale
ctx
.
sample_num
=
sample_num
ctx
.
save_for_backward
(
rois
)
...
...
mmdet/ops/roi_align/modules/roi_align.py
View file @
350fdd7a
from
torch.nn
.modules.module
import
Module
from
..functions.roi_align
import
RoIAlignFunction
import
torch.nn
as
nn
from
torch.nn.modules.utils
import
_pair
from
..functions.roi_align
import
roi_align
class
RoIAlign
(
Module
):
def
__init__
(
self
,
out_size
,
spatial_scale
,
sample_num
=
0
):
class
RoIAlign
(
nn
.
Module
):
def
__init__
(
self
,
out_size
,
spatial_scale
,
sample_num
=
0
,
use_torchvision
=
False
):
super
(
RoIAlign
,
self
).
__init__
()
self
.
out_size
=
out_size
self
.
spatial_scale
=
float
(
spatial_scale
)
self
.
sample_num
=
int
(
sample_num
)
self
.
use_torchvision
=
use_torchvision
def
forward
(
self
,
features
,
rois
):
return
RoIAlignFunction
.
apply
(
features
,
rois
,
self
.
out_size
,
if
self
.
use_torchvision
:
from
torchvision.ops
import
roi_align
as
tv_roi_align
return
tv_roi_align
(
features
,
rois
,
_pair
(
self
.
out_size
),
self
.
spatial_scale
,
self
.
sample_num
)
else
:
return
roi_align
(
features
,
rois
,
self
.
out_size
,
self
.
spatial_scale
,
self
.
sample_num
)
mmdet/ops/roi_pool/functions/roi_pool.py
View file @
350fdd7a
import
torch
from
torch.autograd
import
Function
from
torch.nn.modules.utils
import
_pair
from
..
import
roi_pool_cuda
...
...
@@ -8,18 +9,9 @@ class RoIPoolFunction(Function):
@
staticmethod
def
forward
(
ctx
,
features
,
rois
,
out_size
,
spatial_scale
):
if
isinstance
(
out_size
,
int
):
out_h
=
out_size
out_w
=
out_size
elif
isinstance
(
out_size
,
tuple
):
assert
len
(
out_size
)
==
2
assert
isinstance
(
out_size
[
0
],
int
)
assert
isinstance
(
out_size
[
1
],
int
)
out_h
,
out_w
=
out_size
else
:
raise
TypeError
(
'"out_size" must be an integer or tuple of integers'
)
assert
features
.
is_cuda
out_h
,
out_w
=
_pair
(
out_size
)
assert
isinstance
(
out_h
,
int
)
and
isinstance
(
out_w
,
int
)
ctx
.
save_for_backward
(
rois
)
num_channels
=
features
.
size
(
1
)
num_rois
=
rois
.
size
(
0
)
...
...
mmdet/ops/roi_pool/modules/roi_pool.py
View file @
350fdd7a
from
torch.nn.modules.module
import
Module
import
torch.nn
as
nn
from
torch.nn.modules.utils
import
_pair
from
..functions.roi_pool
import
roi_pool
class
RoIPool
(
Module
):
class
RoIPool
(
nn
.
Module
):
def
__init__
(
self
,
out_size
,
spatial_scale
):
def
__init__
(
self
,
out_size
,
spatial_scale
,
use_torchvision
=
False
):
super
(
RoIPool
,
self
).
__init__
()
self
.
out_size
=
out_size
self
.
spatial_scale
=
float
(
spatial_scale
)
self
.
use_torchvision
=
use_torchvision
def
forward
(
self
,
features
,
rois
):
if
self
.
use_torchvision
:
from
torchvision.ops
import
roi_pool
as
tv_roi_pool
return
tv_roi_pool
(
features
,
rois
,
_pair
(
self
.
out_size
),
self
.
spatial_scale
)
else
:
return
roi_pool
(
features
,
rois
,
self
.
out_size
,
self
.
spatial_scale
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment