Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
fa82fd3b
Unverified
Commit
fa82fd3b
authored
Mar 15, 2024
by
Nicolas Hug
Committed by
GitHub
Mar 15, 2024
Browse files
Allow SanitizeBoundingBoxes to sanitize more labels (#8319)
parent
53869eb8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
62 additions
and
21 deletions
+62
-21
test/test_transforms_v2.py
test/test_transforms_v2.py
+24
-2
torchvision/transforms/v2/_misc.py
torchvision/transforms/v2/_misc.py
+36
-15
torchvision/transforms/v2/_utils.py
torchvision/transforms/v2/_utils.py
+2
-4
No files found.
test/test_transforms_v2.py
View file @
fa82fd3b
...
...
@@ -5706,7 +5706,17 @@ class TestSanitizeBoundingBoxes:
return
boxes
,
expected_valid_mask
@
pytest
.
mark
.
parametrize
(
"min_size"
,
(
1
,
10
))
@
pytest
.
mark
.
parametrize
(
"labels_getter"
,
(
"default"
,
lambda
inputs
:
inputs
[
"labels"
],
None
,
lambda
inputs
:
None
))
@
pytest
.
mark
.
parametrize
(
"labels_getter"
,
(
"default"
,
lambda
inputs
:
inputs
[
"labels"
],
lambda
inputs
:
(
inputs
[
"labels"
],
inputs
[
"other_labels"
]),
lambda
inputs
:
[
inputs
[
"labels"
],
inputs
[
"other_labels"
]],
None
,
lambda
inputs
:
None
,
),
)
@
pytest
.
mark
.
parametrize
(
"sample_type"
,
(
tuple
,
dict
))
def
test_transform
(
self
,
min_size
,
labels_getter
,
sample_type
):
...
...
@@ -5721,12 +5731,16 @@ class TestSanitizeBoundingBoxes:
labels
=
torch
.
arange
(
boxes
.
shape
[
0
])
masks
=
tv_tensors
.
Mask
(
torch
.
randint
(
0
,
2
,
size
=
(
boxes
.
shape
[
0
],
H
,
W
)))
# other_labels corresponds to properties from COCO like iscrowd, area...
# We only sanitize it when labels_getter returns a tuple or list
other_labels
=
torch
.
arange
(
boxes
.
shape
[
0
])
whatever
=
torch
.
rand
(
10
)
input_img
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
H
,
W
),
dtype
=
torch
.
uint8
)
sample
=
{
"image"
:
input_img
,
"labels"
:
labels
,
"boxes"
:
boxes
,
"other_labels"
:
other_labels
,
"whatever"
:
whatever
,
"None"
:
None
,
"masks"
:
masks
,
...
...
@@ -5741,12 +5755,14 @@ class TestSanitizeBoundingBoxes:
if
sample_type
is
tuple
:
out_image
=
out
[
0
]
out_labels
=
out
[
1
][
"labels"
]
out_other_labels
=
out
[
1
][
"other_labels"
]
out_boxes
=
out
[
1
][
"boxes"
]
out_masks
=
out
[
1
][
"masks"
]
out_whatever
=
out
[
1
][
"whatever"
]
else
:
out_image
=
out
[
"image"
]
out_labels
=
out
[
"labels"
]
out_other_labels
=
out
[
"other_labels"
]
out_boxes
=
out
[
"boxes"
]
out_masks
=
out
[
"masks"
]
out_whatever
=
out
[
"whatever"
]
...
...
@@ -5757,14 +5773,20 @@ class TestSanitizeBoundingBoxes:
assert
isinstance
(
out_boxes
,
tv_tensors
.
BoundingBoxes
)
assert
isinstance
(
out_masks
,
tv_tensors
.
Mask
)
if
labels_getter
is
None
or
(
callable
(
labels_getter
)
and
labels_getter
(
{
"labels"
:
"blah"
}
)
is
None
):
if
labels_getter
is
None
or
(
callable
(
labels_getter
)
and
labels_getter
(
sample
)
is
None
):
assert
out_labels
is
labels
assert
out_other_labels
is
other_labels
else
:
assert
isinstance
(
out_labels
,
torch
.
Tensor
)
assert
out_boxes
.
shape
[
0
]
==
out_labels
.
shape
[
0
]
==
out_masks
.
shape
[
0
]
# This works because we conveniently set labels to arange(num_boxes)
assert
out_labels
.
tolist
()
==
valid_indices
if
callable
(
labels_getter
)
and
isinstance
(
labels_getter
(
sample
),
(
tuple
,
list
)):
assert_equal
(
out_other_labels
,
out_labels
)
else
:
assert_equal
(
out_other_labels
,
other_labels
)
@
pytest
.
mark
.
parametrize
(
"input_type"
,
(
torch
.
Tensor
,
tv_tensors
.
BoundingBoxes
))
def
test_functional
(
self
,
input_type
):
# Note: the "functional" F.sanitize_bounding_boxes was added after the class, so there is some
...
...
torchvision/transforms/v2/_misc.py
View file @
fa82fd3b
...
...
@@ -321,6 +321,9 @@ class SanitizeBoundingBoxes(Transform):
- have any coordinate outside of their corresponding image. You may want to
call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals.
It can also sanitize other tensors like the "iscrowd" or "area" properties from COCO
(see ``labels_getter`` parameter).
It is recommended to call it at the end of a pipeline, before passing the
input to the models. It is critical to call this transform if
:class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
...
...
@@ -330,18 +333,26 @@ class SanitizeBoundingBoxes(Transform):
Args:
min_size (float, optional): The size below which bounding boxes are removed. Default is 1.
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input.
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
(or anything else that needs to be sanitized along with the bounding boxes).
By default, this will try to find a "labels" key in the input (case-insensitive), if
the input is a dict or it is a tuple whose second element is a dict.
This heuristic should work well with a lot of datasets, including the built-in torchvision datasets.
It can also be a callable that takes the same input
as the transform, and returns the labels.
It can also be a callable that takes the same input as the transform, and returns either:
- A single tensor (the labels)
- A tuple/list of tensors, each of which will be subject to the same sanitization as the bounding boxes.
This is useful to sanitize multiple tensors like the labels, and the "iscrowd" or "area" properties
from COCO.
If ``labels_getter`` is None then only bounding boxes are sanitized.
"""
def
__init__
(
self
,
min_size
:
float
=
1.0
,
labels_getter
:
Union
[
Callable
[[
Any
],
Optional
[
torch
.
Tensor
]
],
str
,
None
]
=
"default"
,
labels_getter
:
Union
[
Callable
[[
Any
],
Any
],
str
,
None
]
=
"default"
,
)
->
None
:
super
().
__init__
()
...
...
@@ -356,18 +367,28 @@ class SanitizeBoundingBoxes(Transform):
inputs
=
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
]
labels
=
self
.
_labels_getter
(
inputs
)
if
labels
is
not
None
and
not
isinstance
(
labels
,
torch
.
Tensor
):
raise
ValueError
(
f
"The labels in the input to forward() must be a tensor or None, got
{
type
(
labels
)
}
instead."
)
if
labels
is
not
None
:
msg
=
"The labels in the input to forward() must be a tensor or None, got {type} instead."
if
isinstance
(
labels
,
torch
.
Tensor
):
labels
=
(
labels
,)
elif
isinstance
(
labels
,
(
tuple
,
list
)):
for
entry
in
labels
:
if
not
isinstance
(
entry
,
torch
.
Tensor
):
# TODO: we don't need to enforce tensors, just that entries are indexable as t[bool_mask]
raise
ValueError
(
msg
.
format
(
type
=
type
(
entry
)))
else
:
raise
ValueError
(
msg
.
format
(
type
=
type
(
labels
)))
flat_inputs
,
spec
=
tree_flatten
(
inputs
)
boxes
=
get_bounding_boxes
(
flat_inputs
)
if
labels
is
not
None
and
boxes
.
shape
[
0
]
!=
labels
.
shape
[
0
]:
raise
ValueError
(
f
"Number of boxes (shape=
{
boxes
.
shape
}
) and number of labels (shape=
{
labels
.
shape
}
) do not match."
)
if
labels
is
not
None
:
for
label
in
labels
:
if
boxes
.
shape
[
0
]
!=
label
.
shape
[
0
]:
raise
ValueError
(
f
"Number of boxes (shape=
{
boxes
.
shape
}
) and must match the number of labels."
f
"Found labels with shape=
{
label
.
shape
}
)."
)
valid
=
F
.
_misc
.
_get_sanitize_bounding_boxes_mask
(
boxes
,
...
...
@@ -381,7 +402,7 @@ class SanitizeBoundingBoxes(Transform):
return
tree_unflatten
(
flat_outputs
,
spec
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
is_label
=
inpt
is
not
None
and
inpt
is
params
[
"labels"
]
is_label
=
params
[
"labels"
]
is
not
None
and
any
(
inpt
is
label
for
label
in
params
[
"labels"
]
)
is_bounding_boxes_or_mask
=
isinstance
(
inpt
,
(
tv_tensors
.
BoundingBoxes
,
tv_tensors
.
Mask
))
if
not
(
is_label
or
is_bounding_boxes_or_mask
):
...
...
@@ -391,5 +412,5 @@ class SanitizeBoundingBoxes(Transform):
if
is_label
:
return
output
return
tv_tensors
.
wrap
(
output
,
like
=
inpt
)
else
:
return
tv_tensors
.
wrap
(
output
,
like
=
inpt
)
torchvision/transforms/v2/_utils.py
View file @
fa82fd3b
...
...
@@ -4,7 +4,7 @@ import collections.abc
import
numbers
from
contextlib
import
suppress
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Literal
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Literal
,
Sequence
,
Tuple
,
Type
,
Union
import
PIL.Image
import
torch
...
...
@@ -139,9 +139,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
return
inputs
[
candidate_key
]
def
_parse_labels_getter
(
labels_getter
:
Union
[
str
,
Callable
[[
Any
],
Optional
[
torch
.
Tensor
]],
None
]
)
->
Callable
[[
Any
],
Optional
[
torch
.
Tensor
]]:
def
_parse_labels_getter
(
labels_getter
:
Union
[
str
,
Callable
[[
Any
],
Any
],
None
])
->
Callable
[[
Any
],
Any
]:
if
labels_getter
==
"default"
:
return
_find_labels_default_heuristic
elif
callable
(
labels_getter
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment