Commit fc838add (project: OpenDAS / vision)

Add deterministic, pure-Python roi_align implementation (#7587)

Authored May 15, 2023 by Edward Z. Yang; committed by GitHub on May 15, 2023
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Parent: a5579189

Showing 2 changed files with 227 additions and 13 deletions:
  test/test_ops.py               +50  -11
  torchvision/ops/roi_align.py  +177   -2
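The user-visible effect of the change, as a rough sketch under stated assumptions (a CUDA-enabled torchvision build; shapes and values below are illustrative, not from the diff): with deterministic algorithms enabled globally, repeated backward passes through roi_align on CUDA should produce identical gradients, because the op now falls back to the pure-Python path added in this commit.

import torch
from torchvision.ops import roi_align

# Illustrative repeatability check (assumes a CUDA device is available).
feat = torch.randn(1, 3, 16, 16, device="cuda", dtype=torch.double, requires_grad=True)
rois = torch.tensor([[0.0, 1.0, 1.0, 9.0, 9.0]], device="cuda", dtype=torch.double)

torch.use_deterministic_algorithms(True)
try:
    grads = []
    for _ in range(2):
        feat.grad = None
        out = roi_align(feat, rois, output_size=(4, 4), spatial_scale=1.0, sampling_ratio=2, aligned=True)
        out.sum().backward()
        grads.append(feat.grad.clone())
    assert torch.equal(grads[0], grads[1])  # identical gradients expected under deterministic mode
finally:
    torch.use_deterministic_algorithms(False)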
test/test_ops.py  (View file @ fc838add)
@@ -19,6 +19,22 @@ from torchvision import models, ops
 from torchvision.models.feature_extraction import get_graph_node_names
 
 
+# Context manager for setting deterministic flag and automatically
+# resetting it to its original value
+class DeterministicGuard:
+    def __init__(self, deterministic, *, warn_only=False):
+        self.deterministic = deterministic
+        self.warn_only = warn_only
+
+    def __enter__(self):
+        self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
+        self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
+        torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore)
+
+
 class RoIOpTesterModuleWrapper(nn.Module):
     def __init__(self, obj):
         super().__init__()
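As a usage sketch (not part of the diff, and assuming the DeterministicGuard class from the hunk above is in scope): the guard saves the global deterministic-algorithms state on entry, applies the requested setting, and restores the saved state on exit.

import torch

torch.use_deterministic_algorithms(False)

with DeterministicGuard(True, warn_only=True):
    # inside the block, deterministic algorithms are requested globally
    assert torch.are_deterministic_algorithms_enabled()
    assert torch.is_deterministic_algorithms_warn_only_enabled()

# on exit, the previous global setting is restored
assert not torch.are_deterministic_algorithms_enabled()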
@@ -83,7 +99,7 @@ class RoIOpTester(ABC):
     @pytest.mark.parametrize("device", cpu_and_gpu())
     @pytest.mark.parametrize("contiguous", (True, False))
-    def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs):
+    def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs):
         x_dtype = self.dtype if x_dtype is None else x_dtype
         rois_dtype = self.dtype if rois_dtype is None else rois_dtype
         pool_size = 5
@@ -99,6 +115,7 @@ class RoIOpTester(ABC):
         )
         pool_h, pool_w = pool_size, pool_size
-        y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
+        with DeterministicGuard(deterministic):
+            y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
         # the following should be true whether we're running an autocast test or not.
         assert y.dtype == x.dtype
@@ -140,7 +157,7 @@ class RoIOpTester(ABC):
     @pytest.mark.parametrize("seed", range(10))
     @pytest.mark.parametrize("device", cpu_and_gpu())
     @pytest.mark.parametrize("contiguous", (True, False))
-    def test_backward(self, seed, device, contiguous):
+    def test_backward(self, seed, device, contiguous, deterministic=False):
         torch.random.manual_seed(seed)
         pool_size = 2
         x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
@@ -155,7 +172,9 @@ class RoIOpTester(ABC):
         script_func = self.get_script_fn(rois, pool_size)
 
-        gradcheck(func, (x,))
+        with DeterministicGuard(deterministic):
+            gradcheck(func, (x,))
+
         gradcheck(script_func, (x,))
 
     @needs_cuda
@@ -384,7 +403,6 @@ class TestRoIAlign(RoIOpTester):
         grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))
-
         for channel in range(0, n_channels):
             val = 0
             for iy in range(0, grid_h):
                 y = start_h + (iy + 0.5) * bin_h / grid_h
@@ -402,21 +420,44 @@ class TestRoIAlign(RoIOpTester):
     @pytest.mark.parametrize("aligned", (True, False))
     @pytest.mark.parametrize("device", cpu_and_gpu())
     @pytest.mark.parametrize("contiguous", (True, False))
-    def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=None):
+    @pytest.mark.parametrize("deterministic", (True, False))
+    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None):
+        if deterministic and device == "cpu":
+            pytest.skip("cpu is always deterministic, don't retest")
         super().test_forward(
-            device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned
+            device=device,
+            contiguous=contiguous,
+            deterministic=deterministic,
+            x_dtype=x_dtype,
+            rois_dtype=rois_dtype,
+            aligned=aligned,
         )
 
     @needs_cuda
     @pytest.mark.parametrize("aligned", (True, False))
+    @pytest.mark.parametrize("deterministic", (True, False))
     @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
     @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
-    def test_autocast(self, aligned, x_dtype, rois_dtype):
+    def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
         with torch.cuda.amp.autocast():
             self.test_forward(
-                torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype
+                torch.device("cuda"),
+                contiguous=False,
+                deterministic=deterministic,
+                aligned=aligned,
+                x_dtype=x_dtype,
+                rois_dtype=rois_dtype,
             )
 
+    @pytest.mark.parametrize("seed", range(10))
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("contiguous", (True, False))
+    @pytest.mark.parametrize("deterministic", (True, False))
+    def test_backward(self, seed, device, contiguous, deterministic):
+        if deterministic and device == "cpu":
+            pytest.skip("cpu is always deterministic, don't retest")
+        super().test_backward(seed, device, contiguous, deterministic)
+
     def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
         rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
         rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,))  # set batch index
@@ -978,7 +1019,6 @@ class TestDeformConv:
             weight = init_weight
-
             for d in ["cpu", "cuda"]:
                 out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
                 out.mean().backward()
                 if true_cpu_grads is None:
@@ -1374,7 +1414,6 @@ class TestGeneralizedBoxIouLoss:
     @pytest.mark.parametrize("device", cpu_and_gpu())
     @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
     def test_giou_loss(self, dtype, device):
-
         box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)
 
         # Identical boxes should have loss of 0
torchvision/ops/roi_align.py  (View file @ fc838add)
 from typing import List, Union
 
 import torch
+import torch._dynamo
+import torch.fx
 from torch import nn, Tensor
 from torch.jit.annotations import BroadcastingList2
 from torch.nn.modules.utils import _pair
-from torchvision.extension import _assert_has_ops
+from torchvision.extension import _assert_has_ops, _has_ops
 
 from ..utils import _log_api_usage_once
 from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
# NB: all inputs are tensors
def _bilinear_interpolate(
    input,  # [N, C, H, W]
    roi_batch_ind,  # [K]
    y,  # [K, PH, IY]
    x,  # [K, PW, IX]
    ymask,  # [K, IY]
    xmask,  # [K, IX]
):
    _, channels, height, width = input.size()

    # deal with inverse element out of feature map boundary
    y = y.clamp(min=0)
    x = x.clamp(min=0)
    y_low = y.int()
    x_low = x.int()
    y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
    y_low = torch.where(y_low >= height - 1, height - 1, y_low)
    y = torch.where(y_low >= height - 1, y.to(input.dtype), y)

    x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
    x_low = torch.where(x_low >= width - 1, width - 1, x_low)
    x = torch.where(x_low >= width - 1, x.to(input.dtype), x)

    ly = y - y_low
    lx = x - x_low
    hy = 1.0 - ly
    hx = 1.0 - lx

    # do bilinear interpolation, but respect the masking!
    # TODO: It's possible the masking here is unnecessary if y and
    # x were clamped appropriately; hard to tell
    def masked_index(
        y,  # [K, PH, IY]
        x,  # [K, PW, IX]
    ):
        if ymask is not None:
            assert xmask is not None
            y = torch.where(ymask[:, None, :], y, 0)
            x = torch.where(xmask[:, None, :], x, 0)
        return input[
            roi_batch_ind[:, None, None, None, None, None],
            torch.arange(channels, device=input.device)[None, :, None, None, None, None],
            y[:, None, :, None, :, None],  # prev [K, PH, IY]
            x[:, None, None, :, None, :],  # prev [K, PW, IX]
        ]  # [K, C, PH, PW, IY, IX]

    v1 = masked_index(y_low, x_low)
    v2 = masked_index(y_low, x_high)
    v3 = masked_index(y_high, x_low)
    v4 = masked_index(y_high, x_high)

    # all ws preemptively [K, C, PH, PW, IY, IX]
    def outer_prod(y, x):
        return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]

    w1 = outer_prod(hy, hx)
    w2 = outer_prod(hy, lx)
    w3 = outer_prod(ly, hx)
    w4 = outer_prod(ly, lx)

    val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
    return val
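The advanced-indexing expression inside masked_index above is the core trick: four broadcastable integer index tensors gather every (ROI, channel, output bin, sample point) combination in a single lookup, producing a [K, C, PH, PW, IY, IX] tensor. A self-contained sketch with tiny, made-up shapes (all names below are illustrative, not taken from the diff):

import torch

# Tiny stand-in shapes: N=2 images, C=3 channels, H=W=4, K=2 ROIs,
# PH=PW=1 output bin, IY=IX=2 sample points per bin.
N, C, H, W, K, PH, PW, IY, IX = 2, 3, 4, 4, 2, 1, 1, 2, 2
feat = torch.arange(N * C * H * W, dtype=torch.float32).reshape(N, C, H, W)
roi_batch_ind = torch.tensor([0, 1])       # [K]
y_idx = torch.randint(0, H, (K, PH, IY))   # [K, PH, IY]
x_idx = torch.randint(0, W, (K, PW, IX))   # [K, PW, IX]

gathered = feat[
    roi_batch_ind[:, None, None, None, None, None],    # broadcasts over K
    torch.arange(C)[None, :, None, None, None, None],  # broadcasts over C
    y_idx[:, None, :, None, :, None],                  # contributes PH and IY
    x_idx[:, None, None, :, None, :],                  # contributes PW and IX
]
print(gathered.shape)  # torch.Size([2, 3, 1, 1, 2, 2]) == [K, C, PH, PW, IY, IX]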
# TODO: this doesn't actually cache
# TODO: main library should make this easier to do
def maybe_cast(tensor):
    if torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double:
        return tensor.float()
    else:
        return tensor
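A small sketch of what maybe_cast above does in practice (assumes a CUDA device; illustrative only, not part of the diff): under CUDA autocast, half-precision tensors are upcast to float32 so the pure-Python arithmetic runs in full precision, while outside autocast the tensor passes through unchanged.

import torch

t = torch.randn(4, device="cuda", dtype=torch.half)

with torch.cuda.amp.autocast():
    assert maybe_cast(t).dtype == torch.float32  # upcast while autocast is active

assert maybe_cast(t).dtype == torch.half  # no autocast: returned as-is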
# This is a slow but pure Python and differentiable implementation of
# roi_align.  It potentially is a good basis for Inductor compilation
# (but I have not benchmarked it) but today it is solely used for the
# fact that its backwards can be implemented deterministically,
# which is needed for the PT2 benchmark suite.
#
# It is transcribed directly off of the roi_align CUDA kernel, see
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@torch._dynamo.allow_in_graph
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    orig_dtype = input.dtype

    input = maybe_cast(input)
    rois = maybe_cast(rois)

    _, _, height, width = input.size()

    ph = torch.arange(pooled_height, device=input.device)  # [PH]
    pw = torch.arange(pooled_width, device=input.device)  # [PW]

    # input: [N, C, H, W]
    # rois: [K, 5]

    roi_batch_ind = rois[:, 0].int()  # [K]
    offset = 0.5 if aligned else 0.0
    roi_start_w = rois[:, 1] * spatial_scale - offset  # [K]
    roi_start_h = rois[:, 2] * spatial_scale - offset  # [K]
    roi_end_w = rois[:, 3] * spatial_scale - offset  # [K]
    roi_end_h = rois[:, 4] * spatial_scale - offset  # [K]

    roi_width = roi_end_w - roi_start_w  # [K]
    roi_height = roi_end_h - roi_start_h  # [K]
    if not aligned:
        roi_width = torch.clamp(roi_width, min=1.0)  # [K]
        roi_height = torch.clamp(roi_height, min=1.0)  # [K]

    bin_size_h = roi_height / pooled_height  # [K]
    bin_size_w = roi_width / pooled_width  # [K]

    exact_sampling = sampling_ratio > 0

    roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height)  # scalar or [K]
    roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width)  # scalar or [K]

    """
    iy, ix = dims(2)
    """

    if exact_sampling:
        count = max(roi_bin_grid_h * roi_bin_grid_w, 1)  # scalar
        iy = torch.arange(roi_bin_grid_h, device=input.device)  # [IY]
        ix = torch.arange(roi_bin_grid_w, device=input.device)  # [IX]
        ymask = None
        xmask = None
    else:
        count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1)  # [K]
        # When doing adaptive sampling, the number of samples we need to do
        # is data-dependent based on how big the ROIs are.  This is a bit
        # awkward because first-class dims can't actually handle this.
        # So instead, we inefficiently suppose that we needed to sample ALL
        # the points and mask out things that turned out to be unnecessary
        iy = torch.arange(height, device=input.device)  # [IY]
        ix = torch.arange(width, device=input.device)  # [IX]
        ymask = iy[None, :] < roi_bin_grid_h[:, None]  # [K, IY]
        xmask = ix[None, :] < roi_bin_grid_w[:, None]  # [K, IX]

    def from_K(t):
        return t[:, None, None]

    y = (
        from_K(roi_start_h)
        + ph[None, :, None] * from_K(bin_size_h)
        + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h)
    )  # [K, PH, IY]
    x = (
        from_K(roi_start_w)
        + pw[None, :, None] * from_K(bin_size_w)
        + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w)
    )  # [K, PW, IX]
    val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]

    # Mask out samples that weren't actually adaptively needed
    if not exact_sampling:
        val = torch.where(ymask[:, None, None, None, :, None], val, 0)
        val = torch.where(xmask[:, None, None, None, None, :], val, 0)

    output = val.sum((-1, -2))  # remove IY, IX ~> [K, C, PH, PW]
    if isinstance(count, torch.Tensor):
        output /= count[:, None, None, None]
    else:
        output /= count

    output = output.to(orig_dtype)

    return output
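For orientation, a minimal call of the function above (shapes are illustrative and not from the diff; the positional arguments follow the definition): K boxes given as [batch_index, x1, y1, x2, y2] rows produce a [K, C, pooled_height, pooled_width] output.

import torch

# Hypothetical inputs: N=1 image, C=3 channels, 14x14 feature map, K=2 ROIs.
feat = torch.randn(1, 3, 14, 14)
rois = torch.tensor([[0.0, 0.0, 0.0, 6.0, 6.0],
                     [0.0, 3.0, 3.0, 12.0, 12.0]])  # [batch_index, x1, y1, x2, y2]

out = _roi_align(feat, rois, spatial_scale=1.0, pooled_height=5, pooled_width=5,
                 sampling_ratio=-1, aligned=True)  # ratio <= 0 takes the adaptive-sampling branch
print(out.shape)  # torch.Size([2, 3, 5, 5]) == [K, C, PH, PW]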
@torch.fx.wrap
def roi_align(
    input: Tensor,
    ...
@@ -54,12 +226,15 @@ def roi_align(
     """
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(roi_align)
-    _assert_has_ops()
     check_roi_boxes_shape(boxes)
     rois = boxes
     output_size = _pair(output_size)
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
+    if not torch.jit.is_scripting():
+        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda):
+            return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
+    _assert_has_ops()
     return torch.ops.torchvision.roi_align(
         input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
     )
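A hedged cross-check of the two dispatch targets above (assumes a torchvision build with the compiled ops present; importing the private _roi_align is for illustration only): on CPU with deterministic mode off, the public roi_align goes through torch.ops.torchvision.roi_align, and the new pure-Python path should agree with it up to float tolerance.

import torch
import torchvision.ops as ops
from torchvision.ops.roi_align import _roi_align  # pure-Python path added above

feat = torch.randn(2, 4, 12, 12)
rois = torch.tensor([[0.0, 0.0, 0.0, 8.0, 8.0],
                     [1.0, 2.0, 2.0, 10.0, 10.0]])

kernel_out = ops.roi_align(feat, rois, output_size=(3, 3), spatial_scale=1.0,
                           sampling_ratio=2, aligned=True)  # compiled kernel path
python_out = _roi_align(feat, rois, 1.0, 3, 3, 2, True)     # new Python path
torch.testing.assert_close(kernel_out, python_out, rtol=1e-4, atol=1e-4)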