Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
10239873
Unverified
Commit
10239873
authored
Jun 04, 2024
by
Antoine Broyelle
Committed by
GitHub
Jun 04, 2024
Browse files
Add min_area to `SanitizeBoundingBox` (#7735)
Co-authored-by:
Nicolas Hug
<
contact@nicolas-hug.com
>
parent
f7d9e75b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
18 deletions
+36
-18
test/test_transforms_v2.py
test/test_transforms_v2.py
+15
-12
torchvision/transforms/v2/_misc.py
torchvision/transforms/v2/_misc.py
+10
-2
torchvision/transforms/v2/functional/_misc.py
torchvision/transforms/v2/functional/_misc.py
+11
-4
No files found.
test/test_transforms_v2.py
View file @
10239873
...
@@ -5805,7 +5805,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
...
@@ -5805,7 +5805,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
class
TestSanitizeBoundingBoxes
:
class
TestSanitizeBoundingBoxes
:
def
_get_boxes_and_valid_mask
(
self
,
H
=
256
,
W
=
128
,
min_size
=
10
):
def
_get_boxes_and_valid_mask
(
self
,
H
=
256
,
W
=
128
,
min_size
=
10
,
min_area
=
10
):
boxes_and_validity
=
[
boxes_and_validity
=
[
([
0
,
1
,
10
,
1
],
False
),
# Y1 == Y2
([
0
,
1
,
10
,
1
],
False
),
# Y1 == Y2
([
0
,
1
,
0
,
20
],
False
),
# X1 == X2
([
0
,
1
,
0
,
20
],
False
),
# X1 == X2
...
@@ -5816,17 +5816,16 @@ class TestSanitizeBoundingBoxes:
...
@@ -5816,17 +5816,16 @@ class TestSanitizeBoundingBoxes:
([
-
1
,
1
,
10
,
20
],
False
),
# any < 0
([
-
1
,
1
,
10
,
20
],
False
),
# any < 0
([
0
,
0
,
-
1
,
20
],
False
),
# any < 0
([
0
,
0
,
-
1
,
20
],
False
),
# any < 0
([
0
,
0
,
-
10
,
-
1
],
False
),
# any < 0
([
0
,
0
,
-
10
,
-
1
],
False
),
# any < 0
([
0
,
0
,
min_size
,
10
],
True
),
# H < min_size
([
0
,
0
,
min_size
,
10
],
min_size
*
10
>=
min_area
),
# H < min_size
([
0
,
0
,
10
,
min_size
],
True
),
# W < min_size
([
0
,
0
,
10
,
min_size
],
min_size
*
10
>=
min_area
),
# W < min_size
([
0
,
0
,
W
,
H
],
True
),
# TODO: Is that actually OK?? Should it be -1?
([
0
,
0
,
W
,
H
],
W
*
H
>=
min_area
),
([
1
,
1
,
30
,
20
],
True
),
([
1
,
1
,
30
,
20
],
29
*
19
>=
min_area
),
([
0
,
0
,
10
,
10
],
True
),
([
0
,
0
,
10
,
10
],
9
*
9
>=
min_area
),
([
1
,
1
,
30
,
20
],
True
),
([
1
,
1
,
30
,
20
],
29
*
19
>=
min_area
),
]
]
random
.
shuffle
(
boxes_and_validity
)
# For test robustness: mix order of wrong and correct cases
random
.
shuffle
(
boxes_and_validity
)
# For test robustness: mix order of wrong and correct cases
boxes
,
expected_valid_mask
=
zip
(
*
boxes_and_validity
)
boxes
,
expected_valid_mask
=
zip
(
*
boxes_and_validity
)
boxes
=
tv_tensors
.
BoundingBoxes
(
boxes
=
tv_tensors
.
BoundingBoxes
(
boxes
,
boxes
,
format
=
tv_tensors
.
BoundingBoxFormat
.
XYXY
,
format
=
tv_tensors
.
BoundingBoxFormat
.
XYXY
,
...
@@ -5835,7 +5834,7 @@ class TestSanitizeBoundingBoxes:
...
@@ -5835,7 +5834,7 @@ class TestSanitizeBoundingBoxes:
return
boxes
,
expected_valid_mask
return
boxes
,
expected_valid_mask
@
pytest
.
mark
.
parametrize
(
"min_size"
,
(
1
,
1
0
))
@
pytest
.
mark
.
parametrize
(
"min_size
, min_area
"
,
(
(
1
,
1
),
(
10
,
1
),
(
10
,
101
)
))
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"labels_getter"
,
"labels_getter"
,
(
(
...
@@ -5848,7 +5847,7 @@ class TestSanitizeBoundingBoxes:
...
@@ -5848,7 +5847,7 @@ class TestSanitizeBoundingBoxes:
),
),
)
)
@
pytest
.
mark
.
parametrize
(
"sample_type"
,
(
tuple
,
dict
))
@
pytest
.
mark
.
parametrize
(
"sample_type"
,
(
tuple
,
dict
))
def
test_transform
(
self
,
min_size
,
labels_getter
,
sample_type
):
def
test_transform
(
self
,
min_size
,
min_area
,
labels_getter
,
sample_type
):
if
sample_type
is
tuple
and
not
isinstance
(
labels_getter
,
str
):
if
sample_type
is
tuple
and
not
isinstance
(
labels_getter
,
str
):
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
...
@@ -5856,7 +5855,7 @@ class TestSanitizeBoundingBoxes:
...
@@ -5856,7 +5855,7 @@ class TestSanitizeBoundingBoxes:
return
return
H
,
W
=
256
,
128
H
,
W
=
256
,
128
boxes
,
expected_valid_mask
=
self
.
_get_boxes_and_valid_mask
(
H
=
H
,
W
=
W
,
min_size
=
min_size
)
boxes
,
expected_valid_mask
=
self
.
_get_boxes_and_valid_mask
(
H
=
H
,
W
=
W
,
min_size
=
min_size
,
min_area
=
min_area
)
valid_indices
=
[
i
for
(
i
,
is_valid
)
in
enumerate
(
expected_valid_mask
)
if
is_valid
]
valid_indices
=
[
i
for
(
i
,
is_valid
)
in
enumerate
(
expected_valid_mask
)
if
is_valid
]
labels
=
torch
.
arange
(
boxes
.
shape
[
0
])
labels
=
torch
.
arange
(
boxes
.
shape
[
0
])
...
@@ -5880,7 +5879,9 @@ class TestSanitizeBoundingBoxes:
...
@@ -5880,7 +5879,9 @@ class TestSanitizeBoundingBoxes:
img
=
sample
.
pop
(
"image"
)
img
=
sample
.
pop
(
"image"
)
sample
=
(
img
,
sample
)
sample
=
(
img
,
sample
)
out
=
transforms
.
SanitizeBoundingBoxes
(
min_size
=
min_size
,
labels_getter
=
labels_getter
)(
sample
)
out
=
transforms
.
SanitizeBoundingBoxes
(
min_size
=
min_size
,
min_area
=
min_area
,
labels_getter
=
labels_getter
)(
sample
)
if
sample_type
is
tuple
:
if
sample_type
is
tuple
:
out_image
=
out
[
0
]
out_image
=
out
[
0
]
...
@@ -5977,6 +5978,8 @@ class TestSanitizeBoundingBoxes:
...
@@ -5977,6 +5978,8 @@ class TestSanitizeBoundingBoxes:
with
pytest
.
raises
(
ValueError
,
match
=
"min_size must be >= 1"
):
with
pytest
.
raises
(
ValueError
,
match
=
"min_size must be >= 1"
):
transforms
.
SanitizeBoundingBoxes
(
min_size
=
0
)
transforms
.
SanitizeBoundingBoxes
(
min_size
=
0
)
with
pytest
.
raises
(
ValueError
,
match
=
"min_area must be >= 1"
):
transforms
.
SanitizeBoundingBoxes
(
min_area
=
0
)
with
pytest
.
raises
(
ValueError
,
match
=
"labels_getter should either be 'default'"
):
with
pytest
.
raises
(
ValueError
,
match
=
"labels_getter should either be 'default'"
):
transforms
.
SanitizeBoundingBoxes
(
labels_getter
=
12
)
transforms
.
SanitizeBoundingBoxes
(
labels_getter
=
12
)
...
...
torchvision/transforms/v2/_misc.py
View file @
10239873
...
@@ -344,7 +344,7 @@ class SanitizeBoundingBoxes(Transform):
...
@@ -344,7 +344,7 @@ class SanitizeBoundingBoxes(Transform):
This transform removes bounding boxes and their associated labels/masks that:
This transform removes bounding boxes and their associated labels/masks that:
- are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- are below a given ``min_size``
or ``min_area``
: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- have any coordinate outside of their corresponding image. You may want to
- have any coordinate outside of their corresponding image. You may want to
call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals.
call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals.
...
@@ -359,7 +359,8 @@ class SanitizeBoundingBoxes(Transform):
...
@@ -359,7 +359,8 @@ class SanitizeBoundingBoxes(Transform):
cases.
cases.
Args:
Args:
min_size (float, optional) The size below which bounding boxes are removed. Default is 1.
min_size (float, optional): The size below which bounding boxes are removed. Default is 1.
min_area (float, optional): The area below which bounding boxes are removed. Default is 1.
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
(or anything else that needs to be sanitized along with the bounding boxes).
(or anything else that needs to be sanitized along with the bounding boxes).
By default, this will try to find a "labels" key in the input (case-insensitive), if
By default, this will try to find a "labels" key in the input (case-insensitive), if
...
@@ -379,6 +380,7 @@ class SanitizeBoundingBoxes(Transform):
...
@@ -379,6 +380,7 @@ class SanitizeBoundingBoxes(Transform):
def
__init__
(
def
__init__
(
self
,
self
,
min_size
:
float
=
1.0
,
min_size
:
float
=
1.0
,
min_area
:
float
=
1.0
,
labels_getter
:
Union
[
Callable
[[
Any
],
Any
],
str
,
None
]
=
"default"
,
labels_getter
:
Union
[
Callable
[[
Any
],
Any
],
str
,
None
]
=
"default"
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -387,6 +389,10 @@ class SanitizeBoundingBoxes(Transform):
...
@@ -387,6 +389,10 @@ class SanitizeBoundingBoxes(Transform):
raise
ValueError
(
f
"min_size must be >= 1, got
{
min_size
}
."
)
raise
ValueError
(
f
"min_size must be >= 1, got
{
min_size
}
."
)
self
.
min_size
=
min_size
self
.
min_size
=
min_size
if
min_area
<
1
:
raise
ValueError
(
f
"min_area must be >= 1, got
{
min_area
}
."
)
self
.
min_area
=
min_area
self
.
labels_getter
=
labels_getter
self
.
labels_getter
=
labels_getter
self
.
_labels_getter
=
_parse_labels_getter
(
labels_getter
)
self
.
_labels_getter
=
_parse_labels_getter
(
labels_getter
)
...
@@ -422,7 +428,9 @@ class SanitizeBoundingBoxes(Transform):
...
@@ -422,7 +428,9 @@ class SanitizeBoundingBoxes(Transform):
format
=
boxes
.
format
,
format
=
boxes
.
format
,
canvas_size
=
boxes
.
canvas_size
,
canvas_size
=
boxes
.
canvas_size
,
min_size
=
self
.
min_size
,
min_size
=
self
.
min_size
,
min_area
=
self
.
min_area
,
)
)
params
=
dict
(
valid
=
valid
,
labels
=
labels
)
params
=
dict
(
valid
=
valid
,
labels
=
labels
)
flat_outputs
=
[
self
.
_transform
(
inpt
,
params
)
for
inpt
in
flat_inputs
]
flat_outputs
=
[
self
.
_transform
(
inpt
,
params
)
for
inpt
in
flat_inputs
]
...
...
torchvision/transforms/v2/functional/_misc.py
View file @
10239873
...
@@ -322,12 +322,13 @@ def sanitize_bounding_boxes(
...
@@ -322,12 +322,13 @@ def sanitize_bounding_boxes(
format
:
Optional
[
tv_tensors
.
BoundingBoxFormat
]
=
None
,
format
:
Optional
[
tv_tensors
.
BoundingBoxFormat
]
=
None
,
canvas_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
canvas_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
min_size
:
float
=
1.0
,
min_size
:
float
=
1.0
,
min_area
:
float
=
1.0
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Remove degenerate/invalid bounding boxes and return the corresponding indexing mask.
"""Remove degenerate/invalid bounding boxes and return the corresponding indexing mask.
This removes bounding boxes that:
This removes bounding boxes that:
- are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- are below a given ``min_size``
or ``min_area``
: by default this also removes degenerate boxes that have e.g. X2 <= X1.
- have any coordinate outside of their corresponding image. You may want to
- have any coordinate outside of their corresponding image. You may want to
call :func:`~torchvision.transforms.v2.functional.clamp_bounding_boxes` first to avoid undesired removals.
call :func:`~torchvision.transforms.v2.functional.clamp_bounding_boxes` first to avoid undesired removals.
...
@@ -346,6 +347,7 @@ def sanitize_bounding_boxes(
...
@@ -346,6 +347,7 @@ def sanitize_bounding_boxes(
(size of the corresponding image/video).
(size of the corresponding image/video).
Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
min_size (float, optional) The size below which bounding boxes are removed. Default is 1.
min_size (float, optional) The size below which bounding boxes are removed. Default is 1.
min_area (float, optional) The area below which bounding boxes are removed. Default is 1.
Returns:
Returns:
out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask.
out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask.
...
@@ -361,7 +363,7 @@ def sanitize_bounding_boxes(
...
@@ -361,7 +363,7 @@ def sanitize_bounding_boxes(
if
isinstance
(
format
,
str
):
if
isinstance
(
format
,
str
):
format
=
tv_tensors
.
BoundingBoxFormat
[
format
.
upper
()]
format
=
tv_tensors
.
BoundingBoxFormat
[
format
.
upper
()]
valid
=
_get_sanitize_bounding_boxes_mask
(
valid
=
_get_sanitize_bounding_boxes_mask
(
bounding_boxes
,
format
=
format
,
canvas_size
=
canvas_size
,
min_size
=
min_size
bounding_boxes
,
format
=
format
,
canvas_size
=
canvas_size
,
min_size
=
min_size
,
min_area
=
min_area
)
)
bounding_boxes
=
bounding_boxes
[
valid
]
bounding_boxes
=
bounding_boxes
[
valid
]
else
:
else
:
...
@@ -374,7 +376,11 @@ def sanitize_bounding_boxes(
...
@@ -374,7 +376,11 @@ def sanitize_bounding_boxes(
"Leave those to None or pass bounding_boxes as a pure tensor."
"Leave those to None or pass bounding_boxes as a pure tensor."
)
)
valid
=
_get_sanitize_bounding_boxes_mask
(
valid
=
_get_sanitize_bounding_boxes_mask
(
bounding_boxes
,
format
=
bounding_boxes
.
format
,
canvas_size
=
bounding_boxes
.
canvas_size
,
min_size
=
min_size
bounding_boxes
,
format
=
bounding_boxes
.
format
,
canvas_size
=
bounding_boxes
.
canvas_size
,
min_size
=
min_size
,
min_area
=
min_area
,
)
)
bounding_boxes
=
tv_tensors
.
wrap
(
bounding_boxes
[
valid
],
like
=
bounding_boxes
)
bounding_boxes
=
tv_tensors
.
wrap
(
bounding_boxes
[
valid
],
like
=
bounding_boxes
)
...
@@ -386,6 +392,7 @@ def _get_sanitize_bounding_boxes_mask(
...
@@ -386,6 +392,7 @@ def _get_sanitize_bounding_boxes_mask(
format
:
tv_tensors
.
BoundingBoxFormat
,
format
:
tv_tensors
.
BoundingBoxFormat
,
canvas_size
:
Tuple
[
int
,
int
],
canvas_size
:
Tuple
[
int
,
int
],
min_size
:
float
=
1.0
,
min_size
:
float
=
1.0
,
min_area
:
float
=
1.0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
bounding_boxes
=
_convert_bounding_box_format
(
bounding_boxes
=
_convert_bounding_box_format
(
...
@@ -394,7 +401,7 @@ def _get_sanitize_bounding_boxes_mask(
...
@@ -394,7 +401,7 @@ def _get_sanitize_bounding_boxes_mask(
image_h
,
image_w
=
canvas_size
image_h
,
image_w
=
canvas_size
ws
,
hs
=
bounding_boxes
[:,
2
]
-
bounding_boxes
[:,
0
],
bounding_boxes
[:,
3
]
-
bounding_boxes
[:,
1
]
ws
,
hs
=
bounding_boxes
[:,
2
]
-
bounding_boxes
[:,
0
],
bounding_boxes
[:,
3
]
-
bounding_boxes
[:,
1
]
valid
=
(
ws
>=
min_size
)
&
(
hs
>=
min_size
)
&
(
bounding_boxes
>=
0
).
all
(
dim
=-
1
)
valid
=
(
ws
>=
min_size
)
&
(
hs
>=
min_size
)
&
(
bounding_boxes
>=
0
).
all
(
dim
=-
1
)
&
(
ws
*
hs
>=
min_area
)
# TODO: Do we really need to check for out of bounds here? All
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
# transforms should be clamping anyway, so this should never happen?
image_h
,
image_w
=
canvas_size
image_h
,
image_w
=
canvas_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment