Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
2b70774e
Unverified
Commit
2b70774e
authored
Feb 24, 2023
by
mpearce25
Committed by
GitHub
Feb 24, 2023
Browse files
Singular Sanitize BoundingBox (#7316)
Co-authored-by:
Nicolas Hug
<
contact@nicolas-hug.com
>
parent
0daffad3
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
22 additions
and
22 deletions
+22
-22
gallery/plot_transforms_v2_e2e.py
gallery/plot_transforms_v2_e2e.py
+2
-2
test/test_transforms_v2.py
test/test_transforms_v2.py
+13
-13
test/test_transforms_v2_consistency.py
test/test_transforms_v2_consistency.py
+1
-1
torchvision/transforms/v2/__init__.py
torchvision/transforms/v2/__init__.py
+1
-1
torchvision/transforms/v2/_geometry.py
torchvision/transforms/v2/_geometry.py
+2
-2
torchvision/transforms/v2/_misc.py
torchvision/transforms/v2/_misc.py
+3
-3
No files found.
gallery/plot_transforms_v2_e2e.py
View file @
2b70774e
...
@@ -105,13 +105,13 @@ transform = transforms.Compose(
...
@@ -105,13 +105,13 @@ transform = transforms.Compose(
transforms
.
RandomHorizontalFlip
(),
transforms
.
RandomHorizontalFlip
(),
transforms
.
ToImageTensor
(),
transforms
.
ToImageTensor
(),
transforms
.
ConvertImageDtype
(
torch
.
float32
),
transforms
.
ConvertImageDtype
(
torch
.
float32
),
transforms
.
SanitizeBoundingBox
es
(),
transforms
.
SanitizeBoundingBox
(),
]
]
)
)
########################################################################################################################
########################################################################################################################
# .. note::
# .. note::
# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBox
es
` transform is a no-op in this example, but it
# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBox` transform is a no-op in this example, but it
# should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
# should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
# the corresponding labels and optionally masks. It is particularly critical to add it if
# the corresponding labels and optionally masks. It is particularly critical to add it if
# :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
# :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
...
...
test/test_transforms_v2.py
View file @
2b70774e
...
@@ -275,7 +275,7 @@ class TestSmoke:
...
@@ -275,7 +275,7 @@ class TestSmoke:
boxes
=
datapoints
.
BoundingBox
([[
0
,
0
,
0
,
0
]],
format
=
format
,
spatial_size
=
(
224
,
244
)),
boxes
=
datapoints
.
BoundingBox
([[
0
,
0
,
0
,
0
]],
format
=
format
,
spatial_size
=
(
224
,
244
)),
labels
=
torch
.
tensor
([
3
]),
labels
=
torch
.
tensor
([
3
]),
)
)
assert
transforms
.
SanitizeBoundingBox
es
()(
sample
)[
"boxes"
].
shape
==
(
0
,
4
)
assert
transforms
.
SanitizeBoundingBox
()(
sample
)[
"boxes"
].
shape
==
(
0
,
4
)
@
parametrize
(
@
parametrize
(
[
[
...
@@ -1876,7 +1876,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
...
@@ -1876,7 +1876,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
transforms
.
ConvertImageDtype
(
torch
.
float
),
transforms
.
ConvertImageDtype
(
torch
.
float
),
]
]
if
sanitize
:
if
sanitize
:
t
+=
[
transforms
.
SanitizeBoundingBox
es
()]
t
+=
[
transforms
.
SanitizeBoundingBox
()]
t
=
transforms
.
Compose
(
t
)
t
=
transforms
.
Compose
(
t
)
num_boxes
=
5
num_boxes
=
5
...
@@ -1917,7 +1917,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
...
@@ -1917,7 +1917,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
# ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It
# ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It
# doesn't remove them strictly speaking, it just marks some boxes as
# doesn't remove them strictly speaking, it just marks some boxes as
# degenerate and those boxes will be later removed by
# degenerate and those boxes will be later removed by
# SanitizeBoundingBox
es
(), which we add to the pipelines if the sanitize
# SanitizeBoundingBox(), which we add to the pipelines if the sanitize
# param is True.
# param is True.
# Note that the values below are probably specific to the random seed
# Note that the values below are probably specific to the random seed
# set above (which is fine).
# set above (which is fine).
...
@@ -1989,7 +1989,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
...
@@ -1989,7 +1989,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
img
=
sample
.
pop
(
"image"
)
img
=
sample
.
pop
(
"image"
)
sample
=
(
img
,
sample
)
sample
=
(
img
,
sample
)
out
=
transforms
.
SanitizeBoundingBox
es
(
min_size
=
min_size
,
labels_getter
=
labels_getter
)(
sample
)
out
=
transforms
.
SanitizeBoundingBox
(
min_size
=
min_size
,
labels_getter
=
labels_getter
)(
sample
)
if
sample_type
is
tuple
:
if
sample_type
is
tuple
:
out_image
=
out
[
0
]
out_image
=
out
[
0
]
...
@@ -2023,13 +2023,13 @@ def test_sanitize_bounding_boxes_default_heuristic(key, sample_type):
...
@@ -2023,13 +2023,13 @@ def test_sanitize_bounding_boxes_default_heuristic(key, sample_type):
sample
=
{
key
:
labels
,
"another_key"
:
"whatever"
}
sample
=
{
key
:
labels
,
"another_key"
:
"whatever"
}
if
sample_type
is
tuple
:
if
sample_type
is
tuple
:
sample
=
(
None
,
sample
,
"whatever_again"
)
sample
=
(
None
,
sample
,
"whatever_again"
)
assert
transforms
.
SanitizeBoundingBox
es
.
_find_labels_default_heuristic
(
sample
)
is
labels
assert
transforms
.
SanitizeBoundingBox
.
_find_labels_default_heuristic
(
sample
)
is
labels
if
key
.
lower
()
!=
"labels"
:
if
key
.
lower
()
!=
"labels"
:
# If "labels" is in the dict (case-insensitive),
# If "labels" is in the dict (case-insensitive),
# it takes precedence over other keys which would otherwise be a match
# it takes precedence over other keys which would otherwise be a match
d
=
{
key
:
"something_else"
,
"labels"
:
labels
}
d
=
{
key
:
"something_else"
,
"labels"
:
labels
}
assert
transforms
.
SanitizeBoundingBox
es
.
_find_labels_default_heuristic
(
d
)
is
labels
assert
transforms
.
SanitizeBoundingBox
.
_find_labels_default_heuristic
(
d
)
is
labels
def
test_sanitize_bounding_boxes_errors
():
def
test_sanitize_bounding_boxes_errors
():
...
@@ -2041,25 +2041,25 @@ def test_sanitize_bounding_boxes_errors():
...
@@ -2041,25 +2041,25 @@ def test_sanitize_bounding_boxes_errors():
)
)
with
pytest
.
raises
(
ValueError
,
match
=
"min_size must be >= 1"
):
with
pytest
.
raises
(
ValueError
,
match
=
"min_size must be >= 1"
):
transforms
.
SanitizeBoundingBox
es
(
min_size
=
0
)
transforms
.
SanitizeBoundingBox
(
min_size
=
0
)
with
pytest
.
raises
(
ValueError
,
match
=
"labels_getter should either be a str"
):
with
pytest
.
raises
(
ValueError
,
match
=
"labels_getter should either be a str"
):
transforms
.
SanitizeBoundingBox
es
(
labels_getter
=
12
)
transforms
.
SanitizeBoundingBox
(
labels_getter
=
12
)
with
pytest
.
raises
(
ValueError
,
match
=
"Could not infer where the labels are"
):
with
pytest
.
raises
(
ValueError
,
match
=
"Could not infer where the labels are"
):
bad_labels_key
=
{
"bbox"
:
good_bbox
,
"BAD_KEY"
:
torch
.
arange
(
good_bbox
.
shape
[
0
])}
bad_labels_key
=
{
"bbox"
:
good_bbox
,
"BAD_KEY"
:
torch
.
arange
(
good_bbox
.
shape
[
0
])}
transforms
.
SanitizeBoundingBox
es
()(
bad_labels_key
)
transforms
.
SanitizeBoundingBox
()(
bad_labels_key
)
with
pytest
.
raises
(
ValueError
,
match
=
"If labels_getter is a str or 'default'"
):
with
pytest
.
raises
(
ValueError
,
match
=
"If labels_getter is a str or 'default'"
):
not_a_dict
=
(
good_bbox
,
torch
.
arange
(
good_bbox
.
shape
[
0
]))
not_a_dict
=
(
good_bbox
,
torch
.
arange
(
good_bbox
.
shape
[
0
]))
transforms
.
SanitizeBoundingBox
es
()(
not_a_dict
)
transforms
.
SanitizeBoundingBox
()(
not_a_dict
)
with
pytest
.
raises
(
ValueError
,
match
=
"must be a tensor"
):
with
pytest
.
raises
(
ValueError
,
match
=
"must be a tensor"
):
not_a_tensor
=
{
"bbox"
:
good_bbox
,
"labels"
:
torch
.
arange
(
good_bbox
.
shape
[
0
]).
tolist
()}
not_a_tensor
=
{
"bbox"
:
good_bbox
,
"labels"
:
torch
.
arange
(
good_bbox
.
shape
[
0
]).
tolist
()}
transforms
.
SanitizeBoundingBox
es
()(
not_a_tensor
)
transforms
.
SanitizeBoundingBox
()(
not_a_tensor
)
with
pytest
.
raises
(
ValueError
,
match
=
"Number of boxes"
):
with
pytest
.
raises
(
ValueError
,
match
=
"Number of boxes"
):
different_sizes
=
{
"bbox"
:
good_bbox
,
"labels"
:
torch
.
arange
(
good_bbox
.
shape
[
0
]
+
3
)}
different_sizes
=
{
"bbox"
:
good_bbox
,
"labels"
:
torch
.
arange
(
good_bbox
.
shape
[
0
]
+
3
)}
transforms
.
SanitizeBoundingBox
es
()(
different_sizes
)
transforms
.
SanitizeBoundingBox
()(
different_sizes
)
with
pytest
.
raises
(
ValueError
,
match
=
"boxes must be of shape"
):
with
pytest
.
raises
(
ValueError
,
match
=
"boxes must be of shape"
):
bad_bbox
=
datapoints
.
BoundingBox
(
# batch with 2 elements
bad_bbox
=
datapoints
.
BoundingBox
(
# batch with 2 elements
...
@@ -2071,7 +2071,7 @@ def test_sanitize_bounding_boxes_errors():
...
@@ -2071,7 +2071,7 @@ def test_sanitize_bounding_boxes_errors():
spatial_size
=
(
20
,
20
),
spatial_size
=
(
20
,
20
),
)
)
different_sizes
=
{
"bbox"
:
bad_bbox
,
"labels"
:
torch
.
arange
(
bad_bbox
.
shape
[
0
])}
different_sizes
=
{
"bbox"
:
bad_bbox
,
"labels"
:
torch
.
arange
(
bad_bbox
.
shape
[
0
])}
transforms
.
SanitizeBoundingBox
es
()(
different_sizes
)
transforms
.
SanitizeBoundingBox
()(
different_sizes
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
...
test/test_transforms_v2_consistency.py
View file @
2b70774e
...
@@ -1099,7 +1099,7 @@ class TestRefDetTransforms:
...
@@ -1099,7 +1099,7 @@ class TestRefDetTransforms:
v2_transforms
.
Compose
(
v2_transforms
.
Compose
(
[
[
v2_transforms
.
RandomIoUCrop
(),
v2_transforms
.
RandomIoUCrop
(),
v2_transforms
.
SanitizeBoundingBox
es
(
labels_getter
=
lambda
sample
:
sample
[
1
][
"labels"
]),
v2_transforms
.
SanitizeBoundingBox
(
labels_getter
=
lambda
sample
:
sample
[
1
][
"labels"
]),
]
]
),
),
{
"with_mask"
:
False
},
{
"with_mask"
:
False
},
...
...
torchvision/transforms/v2/__init__.py
View file @
2b70774e
...
@@ -40,7 +40,7 @@ from ._geometry import (
...
@@ -40,7 +40,7 @@ from ._geometry import (
TenCrop
,
TenCrop
,
)
)
from
._meta
import
ClampBoundingBox
,
ConvertBoundingBoxFormat
,
ConvertDtype
,
ConvertImageDtype
from
._meta
import
ClampBoundingBox
,
ConvertBoundingBoxFormat
,
ConvertDtype
,
ConvertImageDtype
from
._misc
import
GaussianBlur
,
Identity
,
Lambda
,
LinearTransformation
,
Normalize
,
SanitizeBoundingBox
es
,
ToDtype
from
._misc
import
GaussianBlur
,
Identity
,
Lambda
,
LinearTransformation
,
Normalize
,
SanitizeBoundingBox
,
ToDtype
from
._temporal
import
UniformTemporalSubsample
from
._temporal
import
UniformTemporalSubsample
from
._type_conversion
import
PILToTensor
,
ToImagePIL
,
ToImageTensor
,
ToPILImage
from
._type_conversion
import
PILToTensor
,
ToImagePIL
,
ToImageTensor
,
ToPILImage
...
...
torchvision/transforms/v2/_geometry.py
View file @
2b70774e
...
@@ -1114,7 +1114,7 @@ class RandomIoUCrop(Transform):
...
@@ -1114,7 +1114,7 @@ class RandomIoUCrop(Transform):
.. warning::
.. warning::
In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop`
In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop`
must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBox
es
`, either immediately
must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBox`, either immediately
after or later in the transforms pipeline.
after or later in the transforms pipeline.
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
...
@@ -1222,7 +1222,7 @@ class RandomIoUCrop(Transform):
...
@@ -1222,7 +1222,7 @@ class RandomIoUCrop(Transform):
if
isinstance
(
output
,
datapoints
.
BoundingBox
):
if
isinstance
(
output
,
datapoints
.
BoundingBox
):
# We "mark" the invalid boxes as degenreate, and they can be
# We "mark" the invalid boxes as degenreate, and they can be
# removed by a later call to SanitizeBoundingBox
es
()
# removed by a later call to SanitizeBoundingBox()
output
[
~
params
[
"is_within_crop_area"
]]
=
0
output
[
~
params
[
"is_within_crop_area"
]]
=
0
return
output
return
output
...
...
torchvision/transforms/v2/_misc.py
View file @
2b70774e
...
@@ -246,7 +246,7 @@ class ToDtype(Transform):
...
@@ -246,7 +246,7 @@ class ToDtype(Transform):
return
inpt
.
to
(
dtype
=
dtype
)
return
inpt
.
to
(
dtype
=
dtype
)
class
SanitizeBoundingBox
es
(
Transform
):
class
SanitizeBoundingBox
(
Transform
):
# This removes boxes and their corresponding labels:
# This removes boxes and their corresponding labels:
# - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
# - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
# - boxes with any coordinate outside the range of the image (negative, or > spatial_size)
# - boxes with any coordinate outside the range of the image (negative, or > spatial_size)
...
@@ -269,7 +269,7 @@ class SanitizeBoundingBoxes(Transform):
...
@@ -269,7 +269,7 @@ class SanitizeBoundingBoxes(Transform):
elif
callable
(
labels_getter
):
elif
callable
(
labels_getter
):
self
.
_labels_getter
=
labels_getter
self
.
_labels_getter
=
labels_getter
elif
isinstance
(
labels_getter
,
str
):
elif
isinstance
(
labels_getter
,
str
):
self
.
_labels_getter
=
lambda
inputs
:
SanitizeBoundingBox
es
.
_get_dict_or_second_tuple_entry
(
inputs
)[
self
.
_labels_getter
=
lambda
inputs
:
SanitizeBoundingBox
.
_get_dict_or_second_tuple_entry
(
inputs
)[
labels_getter
# type: ignore[index]
labels_getter
# type: ignore[index]
]
]
elif
labels_getter
is
None
:
elif
labels_getter
is
None
:
...
@@ -300,7 +300,7 @@ class SanitizeBoundingBoxes(Transform):
...
@@ -300,7 +300,7 @@ class SanitizeBoundingBoxes(Transform):
def
_find_labels_default_heuristic
(
inputs
:
Dict
[
str
,
Any
])
->
Optional
[
torch
.
Tensor
]:
def
_find_labels_default_heuristic
(
inputs
:
Dict
[
str
,
Any
])
->
Optional
[
torch
.
Tensor
]:
# Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive
# Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive
# Returns None if nothing is found
# Returns None if nothing is found
inputs
=
SanitizeBoundingBox
es
.
_get_dict_or_second_tuple_entry
(
inputs
)
inputs
=
SanitizeBoundingBox
.
_get_dict_or_second_tuple_entry
(
inputs
)
candidate_key
=
None
candidate_key
=
None
with
suppress
(
StopIteration
):
with
suppress
(
StopIteration
):
candidate_key
=
next
(
key
for
key
in
inputs
.
keys
()
if
key
.
lower
()
==
"labels"
)
candidate_key
=
next
(
key
for
key
in
inputs
.
keys
()
if
key
.
lower
()
==
"labels"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment