OpenDAS / vision · Commit 1e19d73c (unverified)

Add SanitizeBoundingBoxes transform (#7246)

Authored by Nicolas Hug on Feb 15, 2023; committed via GitHub on Feb 15, 2023.
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Parent: c5e9a10d
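This commit replaces the prototype RemoveSmallBoundingBoxes transform with SanitizeBoundingBoxes, which drops degenerate, too-small, or out-of-bounds boxes together with their corresponding labels. A minimal usage sketch, assuming the prototype namespace shown in the diff below (the sample keys and box values here are illustrative, not taken from the commit, and the prototype API may change):

    import torch
    from torchvision.prototype import datapoints, transforms

    H, W = 32, 32
    sample = {
        "image": torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8),
        "boxes": datapoints.BoundingBox(
            [[0, 0, 10, 10], [5, 5, 5, 20]],  # second box has zero width and should be dropped
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=(H, W),
        ),
        "labels": torch.tensor([1, 2]),  # found by the default heuristic via the "labels" key
    }

    # Keeps only valid boxes (and their labels); all other entries pass through unchanged.
    out = transforms.SanitizeBoundingBoxes(min_size=1)(sample)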
Changes: 3 changed files, with 224 additions and 21 deletions (+224, -21)

  test/test_prototype_transforms.py             +116  -0
  torchvision/prototype/transforms/__init__.py    +1  -1
  torchvision/prototype/transforms/_misc.py     +107  -20
test/test_prototype_transforms.py

import itertools
import pathlib
import random
import re
import warnings
from collections import defaultdict

...

@@ -2355,3 +2356,118 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
        out["label"] = torch.tensor(out["label"])
    assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes


@pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize(
    "labels_getter",
    ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None),
)
def test_sanitize_bounding_boxes(min_size, labels_getter):
    H, W = 256, 128

    boxes_and_validity = [
        ([0, 1, 10, 1], False),  # Y1 == Y2
        ([0, 1, 0, 20], False),  # X1 == X2
        ([0, 0, min_size - 1, 10], False),  # W < min_size
        ([0, 0, 10, min_size - 1], False),  # H < min_size
        ([0, 0, 10, H + 1], False),  # Y2 > H
        ([0, 0, W + 1, 10], False),  # X2 > W
        ([-1, 1, 10, 20], False),  # any < 0
        ([0, 0, -1, 20], False),  # any < 0
        ([0, 0, -10, -1], False),  # any < 0
        ([0, 0, min_size, 10], True),  # W == min_size
        ([0, 0, 10, min_size], True),  # H == min_size
        ([0, 0, W, H], True),  # TODO: Is that actually OK?? Should it be -1?
        ([1, 1, 30, 20], True),
        ([0, 0, 10, 10], True),
        ([1, 1, 30, 20], True),
    ]

    random.shuffle(boxes_and_validity)  # For test robustness: mix order of wrong and correct cases

    boxes, is_valid_mask = zip(*boxes_and_validity)
    valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid]

    boxes = torch.tensor(boxes)
    labels = torch.arange(boxes.shape[-2])

    boxes = datapoints.BoundingBox(
        boxes,
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(H, W),
    )

    sample = {
        "image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8),
        "labels": labels,
        "boxes": boxes,
        "whatever": torch.rand(10),
        "None": None,
    }

    out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)

    assert out["image"] is sample["image"]
    assert out["whatever"] is sample["whatever"]

    if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
        assert out["labels"] is sample["labels"]
    else:
        assert isinstance(out["labels"], torch.Tensor)
        assert out["boxes"].shape[:-1] == out["labels"].shape
        # This works because we conveniently set labels to arange(num_boxes)
        assert out["labels"].tolist() == valid_indices


@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
def test_sanitize_bounding_boxes_default_heuristic(key):
    labels = torch.arange(10)
    d = {key: labels}
    assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels

    if key.lower() != "labels":
        # If "labels" is in the dict (case-insensitive),
        # it takes precedence over other keys which would otherwise be a match
        d = {key: "something_else", "labels": labels}
        assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels


def test_sanitize_bounding_boxes_errors():

    good_bbox = datapoints.BoundingBox(
        [[0, 0, 10, 10]],
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(20, 20),
    )

    with pytest.raises(ValueError, match="min_size must be >= 1"):
        transforms.SanitizeBoundingBoxes(min_size=0)

    with pytest.raises(ValueError, match="labels_getter should either be a str"):
        transforms.SanitizeBoundingBoxes(labels_getter=12)

    with pytest.raises(ValueError, match="Could not infer where the labels are"):
        bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
        transforms.SanitizeBoundingBoxes()(bad_labels_key)

    with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"):
        not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0]))
        transforms.SanitizeBoundingBoxes()(not_a_dict)

    with pytest.raises(ValueError, match="must be a tensor"):
        not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
        transforms.SanitizeBoundingBoxes()(not_a_tensor)

    with pytest.raises(ValueError, match="Number of boxes"):
        different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
        transforms.SanitizeBoundingBoxes()(different_sizes)

    with pytest.raises(ValueError, match="boxes must be of shape"):
        bad_bbox = datapoints.BoundingBox(  # batch with 2 elements
            [
                [[0, 0, 10, 10]],
                [[0, 0, 10, 10]],
            ],
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=(20, 20),
        )
        different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
        transforms.SanitizeBoundingBoxes()(different_sizes)
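The labels_getter parametrization in test_sanitize_bounding_boxes above covers the supported ways of telling the transform where the labels live. A short summary sketch, with illustrative key names that are not part of the commit:

    # "default": look for a "labels" key (case-insensitive), else the first key containing "label"
    transforms.SanitizeBoundingBoxes(labels_getter="default")
    # a plain string: use that key of the input dict
    transforms.SanitizeBoundingBoxes(labels_getter="my_labels")
    # a callable: return the labels tensor yourself (returning None skips label filtering)
    transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample["target"]["labels"])
    # None: only the boxes are filtered, labels are left untouched
    transforms.SanitizeBoundingBoxes(labels_getter=None)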
torchvision/prototype/transforms/__init__.py

@@ -49,7 +49,7 @@ from ._misc import (
     LinearTransformation,
     Normalize,
     PermuteDimensions,
-    RemoveSmallBoundingBoxes,
+    SanitizeBoundingBoxes,
     ToDtype,
     TransposeDimensions,
 )
torchvision/prototype/transforms/_misc.py

 import collections
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+from contextlib import suppress
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union

 import PIL.Image
 import torch
+from torch.utils._pytree import tree_flatten, tree_unflatten

 from torchvision import transforms as _transforms
-from torchvision.ops import remove_small_boxes
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, Transform

@@ -225,28 +227,113 @@ class TransposeDimensions(Transform):
         return inpt.transpose(*dims)


-class RemoveSmallBoundingBoxes(Transform):
-    _transformed_types = (datapoints.BoundingBox, datapoints.Mask, datapoints.Label, datapoints.OneHotLabel)
+class SanitizeBoundingBoxes(Transform):
+    # This removes boxes and their corresponding labels:
+    # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
+    # - boxes with any coordinate outside the range of the image (negative, or > spatial_size)

-    def __init__(self, min_size: float = 1.0) -> None:
+    def __init__(
+        self,
+        min_size: float = 1.0,
+        labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default",
+    ) -> None:
         super().__init__()
+
+        if min_size < 1:
+            raise ValueError(f"min_size must be >= 1, got {min_size}.")
         self.min_size = min_size

-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        bounding_box = query_bounding_box(flat_inputs)
-
-        # TODO: We can improve performance here by not using the `remove_small_boxes` function. It requires the box to
-        # be in XYXY format only to calculate the width and height internally. Thus, if the box is in XYWH or CXCYWH
-        # format, we need to convert first just to afterwards compute the width and height again, although they were
-        # there in the first place for these formats.
-        bounding_box = F.convert_format_bounding_box(
-            bounding_box.as_subclass(torch.Tensor),
-            old_format=bounding_box.format,
-            new_format=datapoints.BoundingBoxFormat.XYXY,
-        )
+        self.labels_getter = labels_getter
+        self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]]
+        if labels_getter == "default":
+            self._labels_getter = self._find_labels_default_heuristic
+        elif callable(labels_getter):
+            self._labels_getter = labels_getter
+        elif isinstance(labels_getter, str):
+            self._labels_getter = lambda inputs: inputs[labels_getter]
+        elif labels_getter is None:
+            self._labels_getter = None
+        else:
+            raise ValueError(
+                "labels_getter should either be a str, callable, or 'default'. "
+                f"Got {labels_getter} of type {type(labels_getter)}."
+            )

-        valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size)
-
-        return dict(valid_indices=valid_indices)
+    @staticmethod
+    def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
+        # Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive
+        # Raises a ValueError if nothing is found
+        candidate_key = None
+        with suppress(StopIteration):
+            candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
+        if candidate_key is None:
+            with suppress(StopIteration):
+                candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
+        if candidate_key is None:
+            raise ValueError(
+                "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter? "
+                "If there are no labels and it is by design, pass labels_getter=None."
+            )
+        return inputs[candidate_key]
+
+    def forward(self, *inputs: Any) -> Any:
+        inputs = inputs if len(inputs) > 1 else inputs[0]
+
+        if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping):
+            raise ValueError(
+                f"If labels_getter is a str or 'default' (got {self.labels_getter}), "
+                f"then the input to forward() must be a dict. Got {type(inputs)} instead."
+            )
+
+        if self._labels_getter is None:
+            labels = None
+        else:
+            labels = self._labels_getter(inputs)
+            if labels is not None and not isinstance(labels, torch.Tensor):
+                raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.")
+
+        flat_inputs, spec = tree_flatten(inputs)
+        # TODO: this enforces one single BoundingBox entry.
+        # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
+        # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
+        boxes = query_bounding_box(flat_inputs)
+
+        if boxes.ndim != 2:
+            raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
+
+        if labels is not None and boxes.shape[0] != labels.shape[0]:
+            raise ValueError(
+                f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
+            )
+
+        boxes = cast(
+            datapoints.BoundingBox,
+            F.convert_format_bounding_box(
+                boxes,
+                new_format=datapoints.BoundingBoxFormat.XYXY,
+            ),
+        )
+        ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
+        mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
+        # TODO: Do we really need to check for out of bounds here? All
+        # transforms should be clamping anyway, so this should never happen?
+        image_h, image_w = boxes.spatial_size
+        mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
+        mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
+
+        params = dict(mask=mask, labels=labels)
+        flat_outputs = [
+            # Even though it may look like we're transforming all inputs, we don't:
+            # _transform() will only care about BoundingBoxes and the labels
+            self._transform(inpt, params)
+            for inpt in flat_inputs
+        ]
+
+        return tree_unflatten(flat_outputs, spec)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return inpt.wrap_like(inpt, inpt[params["valid_indices"]])
+        if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox):
+            inpt = inpt[params["mask"]]
+
+        return inpt
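The box filtering in forward() boils down to a plain tensor mask over XYXY boxes. A standalone sketch of the same logic, with illustrative values that are not part of the commit:

    import torch

    min_size = 1
    image_h, image_w = 20, 20
    boxes = torch.tensor(
        [
            [0, 0, 10, 10],   # valid
            [5, 5, 5, 15],    # zero width: dropped
            [0, 0, 10, 25],   # Y2 > image_h: dropped
            [-1, 0, 10, 10],  # negative coordinate: dropped
        ],
        dtype=torch.float32,
    )

    ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
    mask = (ws >= min_size) & (hs >= min_size) & (boxes >= 0).all(dim=-1)
    mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
    mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
    print(mask)  # tensor([ True, False, False, False])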