Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
2e1e0b63
Unverified
Commit
2e1e0b63
authored
May 19, 2019
by
Francisco Massa
Committed by
GitHub
May 19, 2019
Browse files
Fix RoIAlign and RoIPool for non-contiguous gradients (#920)
parent
12d2c737
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
39 additions
and
4 deletions
+39
-4
test/test_ops.py
test/test_ops.py
+35
-0
torchvision/csrc/cpu/ROIAlign_cpu.cpp
torchvision/csrc/cpu/ROIAlign_cpu.cpp
+1
-1
torchvision/csrc/cpu/ROIPool_cpu.cpp
torchvision/csrc/cpu/ROIPool_cpu.cpp
+1
-1
torchvision/csrc/cuda/ROIAlign_cuda.cu
torchvision/csrc/cuda/ROIAlign_cuda.cu
+1
-1
torchvision/csrc/cuda/ROIPool_cuda.cu
torchvision/csrc/cuda/ROIPool_cuda.cu
+1
-1
No files found.
test/test_ops.py
View file @
2e1e0b63
...
...
@@ -135,6 +135,41 @@ class RoIPoolTester(unittest.TestCase):
assert
torch
.
allclose
(
x
.
grad
,
gt_grad
),
'gradient incorrect for roi_pool'
def test_roi_pool_align_non_cont_grad_cpu(self):
    """Check that RoIPool/RoIAlign backward handles non-contiguous gradients.

    Runs each op's backward twice on the same input: once with a contiguous
    upstream gradient and once with a non-contiguous tensor holding the same
    values (built via permute/contiguous/permute).  The resulting input
    gradients must match; a kernel that ignores strides would diverge here.
    """
    devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
    for dev_name in devices:
        dev = torch.device(dev_name)
        # Three 5x5 regions (batch_idx, x1, y1, x2, y2) covering a 10x10 map.
        rois = torch.tensor(
            [[0, 0, 0, 9, 9],
             [0, 0, 5, 5, 9],
             [0, 5, 5, 9, 9]],
            dtype=self.dtype, device=dev)

        grad_cont = torch.rand(3, 1, 5, 5, dtype=self.dtype, device=dev)
        # Same values as grad_cont, but laid out non-contiguously: the two
        # permutes cancel element-wise while the intermediate .contiguous()
        # fixes a different memory order.
        grad = grad_cont.permute(2, 1, 3, 0).contiguous().permute(3, 1, 0, 2)

        for op in ['RoIPool', 'RoIAlign']:
            x = torch.rand(1, 1, 10, 10, dtype=self.dtype, device=dev,
                           requires_grad=True)
            kwargs = {}
            if op == 'RoIAlign':
                # RoIAlign additionally needs a sampling ratio.
                kwargs['sampling_ratio'] = 1
            m = getattr(ops, op)((5, 5), 1, **kwargs)

            # Backward with the contiguous gradient.
            y = m(x, rois)
            y.backward(grad_cont)
            g1 = x.grad.detach().clone()
            del x.grad

            # Backward with the non-contiguous (but value-identical) gradient.
            y = m(x, rois)
            y.backward(grad)
            g2 = x.grad.detach().clone()
            del x.grad

            assert torch.allclose(g1, g2), 'gradient incorrect for {}'.format(op)
def
test_roi_pool_gradcheck_cpu
(
self
):
device
=
torch
.
device
(
'cpu'
)
x
=
torch
.
rand
(
1
,
1
,
10
,
10
,
dtype
=
self
.
dtype
,
device
=
device
,
requires_grad
=
True
)
...
...
torchvision/csrc/cpu/ROIAlign_cpu.cpp
View file @
2e1e0b63
...
...
@@ -456,7 +456,7 @@ at::Tensor ROIAlign_backward_cpu(
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
grad
.
type
(),
"ROIAlign_forward"
,
[
&
]
{
ROIAlignBackward
<
scalar_t
>
(
grad
.
numel
(),
grad
.
contiguous
().
data
<
scalar_t
>
(),
grad
.
data
<
scalar_t
>
(),
spatial_scale
,
channels
,
height
,
...
...
torchvision/csrc/cpu/ROIPool_cpu.cpp
View file @
2e1e0b63
...
...
@@ -205,7 +205,7 @@ at::Tensor ROIPool_backward_cpu(
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
grad
.
type
(),
"ROIPool_backward"
,
[
&
]
{
RoIPoolBackward
<
scalar_t
>
(
grad
.
contiguous
().
data
<
scalar_t
>
(),
grad
.
data
<
scalar_t
>
(),
argmax
.
data
<
int
>
(),
num_rois
,
channels
,
...
...
torchvision/csrc/cuda/ROIAlign_cuda.cu
View file @
2e1e0b63
...
...
@@ -396,7 +396,7 @@ at::Tensor ROIAlign_backward_cuda(
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
grad
.
type
(),
"ROIAlign_backward"
,
[
&
]
{
RoIAlignBackward
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
grad
.
numel
(),
grad
.
contiguous
().
data
<
scalar_t
>
(),
grad
.
data
<
scalar_t
>
(),
spatial_scale
,
channels
,
height
,
...
...
torchvision/csrc/cuda/ROIPool_cuda.cu
View file @
2e1e0b63
...
...
@@ -221,7 +221,7 @@ at::Tensor ROIPool_backward_cuda(
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
grad
.
type
(),
"ROIPool_backward"
,
[
&
]
{
RoIPoolBackward
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
grad
.
numel
(),
grad
.
contiguous
().
data
<
scalar_t
>
(),
grad
.
data
<
scalar_t
>
(),
argmax
.
contiguous
().
data
<
int
>
(),
num_rois
,
spatial_scale
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment