Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b7892d3a
Unverified
Commit
b7892d3a
authored
Feb 16, 2023
by
Nicolas Hug
Committed by
GitHub
Feb 16, 2023
Browse files
Make RandomIoUCrop compatible with SanitizeBoundingBoxes (#7268)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
d4d20f01
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
47 additions
and
42 deletions
+47
-42
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+26
-17
test/test_prototype_transforms_consistency.py
test/test_prototype_transforms_consistency.py
+10
-7
torchvision/transforms/v2/_geometry.py
torchvision/transforms/v2/_geometry.py
+3
-12
torchvision/transforms/v2/_misc.py
torchvision/transforms/v2/_misc.py
+8
-6
No files found.
test/test_prototype_transforms.py
View file @
b7892d3a
...
...
@@ -1488,16 +1488,13 @@ class TestRandomIoUCrop:
fn
.
assert_has_calls
(
expected_calls
)
expected_within_targets
=
sum
(
is_within_crop_area
)
# check number of bboxes vs number of labels:
output_bboxes
=
output
[
1
]
assert
isinstance
(
output_bboxes
,
datapoints
.
BoundingBox
)
assert
len
(
output_bboxes
)
==
expected_within_targets
assert
(
output_bboxes
[
~
is_within_crop_area
]
==
0
).
all
()
output_masks
=
output
[
2
]
assert
isinstance
(
output_masks
,
datapoints
.
Mask
)
assert
len
(
output_masks
)
==
expected_within_targets
class
TestScaleJitter
:
...
...
@@ -2253,10 +2250,11 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
@
pytest
.
mark
.
parametrize
(
"image_type"
,
(
PIL
.
Image
,
torch
.
Tensor
,
datapoints
.
Image
))
@
pytest
.
mark
.
parametrize
(
"label_type"
,
(
torch
.
Tensor
,
list
))
@
pytest
.
mark
.
parametrize
(
"data_augmentation"
,
(
"hflip"
,
"lsj"
,
"multiscale"
,
"ssd"
,
"ssdlite"
))
@
pytest
.
mark
.
parametrize
(
"to_tensor"
,
(
transforms
.
ToTensor
,
transforms
.
ToImageTensor
))
def
test_detection_preset
(
image_type
,
label_type
,
data_augmentation
,
to_tensor
):
@
pytest
.
mark
.
parametrize
(
"sanitize"
,
(
True
,
False
))
def
test_detection_preset
(
image_type
,
data_augmentation
,
to_tensor
,
sanitize
):
torch
.
manual_seed
(
0
)
if
data_augmentation
==
"hflip"
:
t
=
[
transforms
.
RandomHorizontalFlip
(
p
=
1
),
...
...
@@ -2290,20 +2288,20 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
t
=
[
transforms
.
RandomPhotometricDistort
(
p
=
1
),
transforms
.
RandomZoomOut
(
fill
=
defaultdict
(
lambda
:
(
123.0
,
117.0
,
104.0
),
{
datapoints
.
Mask
:
0
})),
# TODO: put back IoUCrop once we remove its hard requirement for Labels
# transforms.RandomIoUCrop(),
transforms
.
RandomIoUCrop
(),
transforms
.
RandomHorizontalFlip
(
p
=
1
),
to_tensor
(),
transforms
.
ConvertImageDtype
(
torch
.
float
),
]
elif
data_augmentation
==
"ssdlite"
:
t
=
[
# TODO: put back IoUCrop once we remove its hard requirement for Labels
# transforms.RandomIoUCrop(),
transforms
.
RandomIoUCrop
(),
transforms
.
RandomHorizontalFlip
(
p
=
1
),
to_tensor
(),
transforms
.
ConvertImageDtype
(
torch
.
float
),
]
if
sanitize
:
t
+=
[
transforms
.
SanitizeBoundingBoxes
()]
t
=
transforms
.
Compose
(
t
)
num_boxes
=
5
...
...
@@ -2317,10 +2315,7 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
assert
is_simple_tensor
(
image
)
label
=
torch
.
randint
(
0
,
10
,
size
=
(
num_boxes
,))
if
label_type
is
list
:
label
=
label
.
tolist
()
# TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
boxes
=
torch
.
randint
(
0
,
min
(
H
,
W
)
//
2
,
size
=
(
num_boxes
,
4
))
boxes
[:,
2
:]
+=
boxes
[:,
:
2
]
boxes
=
boxes
.
clamp
(
min
=
0
,
max
=
min
(
H
,
W
))
...
...
@@ -2343,8 +2338,19 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
assert
isinstance
(
out
[
"image"
],
datapoints
.
Image
)
assert
isinstance
(
out
[
"label"
],
type
(
sample
[
"label"
]))
out
[
"label"
]
=
torch
.
tensor
(
out
[
"label"
])
assert
out
[
"boxes"
].
shape
[
0
]
==
out
[
"masks"
].
shape
[
0
]
==
out
[
"label"
].
shape
[
0
]
==
num_boxes
num_boxes_expected
=
{
# ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It
# doesn't remove them strictly speaking, it just marks some boxes as
# degenerate and those boxes will be later removed by
# SanitizeBoundingBoxes(), which we add to the pipelines if the sanitize
# param is True.
# Note that the values below are probably specific to the random seed
# set above (which is fine).
(
True
,
"ssd"
):
4
,
(
True
,
"ssdlite"
):
4
,
}.
get
((
sanitize
,
data_augmentation
),
num_boxes
)
assert
out
[
"boxes"
].
shape
[
0
]
==
out
[
"masks"
].
shape
[
0
]
==
out
[
"label"
].
shape
[
0
]
==
num_boxes_expected
@
pytest
.
mark
.
parametrize
(
"min_size"
,
(
1
,
10
))
...
...
@@ -2377,7 +2383,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
valid_indices
=
[
i
for
(
i
,
is_valid
)
in
enumerate
(
is_valid_mask
)
if
is_valid
]
boxes
=
torch
.
tensor
(
boxes
)
labels
=
torch
.
arange
(
boxes
.
shape
[
-
2
])
labels
=
torch
.
arange
(
boxes
.
shape
[
0
])
boxes
=
datapoints
.
BoundingBox
(
boxes
,
...
...
@@ -2385,12 +2391,15 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
spatial_size
=
(
H
,
W
),
)
masks
=
datapoints
.
Mask
(
torch
.
randint
(
0
,
2
,
size
=
(
boxes
.
shape
[
0
],
H
,
W
)))
sample
=
{
"image"
:
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
H
,
W
),
dtype
=
torch
.
uint8
),
"labels"
:
labels
,
"boxes"
:
boxes
,
"whatever"
:
torch
.
rand
(
10
),
"None"
:
None
,
"masks"
:
masks
,
}
out
=
transforms
.
SanitizeBoundingBoxes
(
min_size
=
min_size
,
labels_getter
=
labels_getter
)(
sample
)
...
...
@@ -2402,7 +2411,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
assert
out
[
"labels"
]
is
sample
[
"labels"
]
else
:
assert
isinstance
(
out
[
"labels"
],
torch
.
Tensor
)
assert
out
[
"boxes"
].
shape
[
:
-
1
]
==
out
[
"labels"
].
shape
assert
out
[
"boxes"
].
shape
[
0
]
==
out
[
"labels"
].
shape
[
0
]
==
out
[
"masks"
].
shape
[
0
]
# This works because we conveniently set labels to arange(num_boxes)
assert
out
[
"labels"
].
tolist
()
==
valid_indices
...
...
test/test_prototype_transforms_consistency.py
View file @
b7892d3a
...
...
@@ -1090,13 +1090,16 @@ class TestRefDetTransforms:
"t_ref, t, data_kwargs"
,
[
(
det_transforms
.
RandomHorizontalFlip
(
p
=
1.0
),
v2_transforms
.
RandomHorizontalFlip
(
p
=
1.0
),
{}),
# FIXME: make
# v2_transforms.Compose([
# v2_transforms.RandomIoUCrop(),
# v2_transforms.SanitizeBoundingBoxes()
# ])
# work
# (det_transforms.RandomIoUCrop(), v2_transforms.RandomIoUCrop(), {"with_mask": False}),
(
det_transforms
.
RandomIoUCrop
(),
v2_transforms
.
Compose
(
[
v2_transforms
.
RandomIoUCrop
(),
v2_transforms
.
SanitizeBoundingBoxes
(
labels_getter
=
lambda
sample
:
sample
[
1
][
"labels"
]),
]
),
{
"with_mask"
:
False
},
),
(
det_transforms
.
RandomZoomOut
(),
v2_transforms
.
RandomZoomOut
(),
{
"with_mask"
:
False
}),
(
det_transforms
.
ScaleJitter
((
1024
,
1024
)),
v2_transforms
.
ScaleJitter
((
1024
,
1024
)),
{}),
(
...
...
torchvision/transforms/v2/_geometry.py
View file @
b7892d3a
...
...
@@ -721,8 +721,6 @@ class RandomIoUCrop(Transform):
if
left
==
right
or
top
==
bottom
:
continue
# FIXME: I think we can stop here?
# check for any valid boxes with centers within the crop area
xyxy_bboxes
=
F
.
convert_format_bounding_box
(
bboxes
.
as_subclass
(
torch
.
Tensor
),
bboxes
.
format
,
datapoints
.
BoundingBoxFormat
.
XYXY
...
...
@@ -745,23 +743,16 @@ class RandomIoUCrop(Transform):
return
dict
(
top
=
top
,
left
=
left
,
height
=
new_h
,
width
=
new_w
,
is_within_crop_area
=
is_within_crop_area
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
# FIXME: refactor this to not remove anything
if
len
(
params
)
<
1
:
return
inpt
is_within_crop_area
=
params
[
"is_within_crop_area"
]
output
=
F
.
crop
(
inpt
,
top
=
params
[
"top"
],
left
=
params
[
"left"
],
height
=
params
[
"height"
],
width
=
params
[
"width"
])
if
isinstance
(
output
,
datapoints
.
BoundingBox
):
bboxes
=
output
[
is_within_crop_area
]
bboxes
=
F
.
clamp_bounding_box
(
bboxes
,
output
.
format
,
output
.
spatial_size
)
output
=
datapoints
.
BoundingBox
.
wrap_like
(
output
,
bboxes
)
elif
isinstance
(
output
,
datapoints
.
Mask
):
# apply is_within_crop_area if mask is one-hot encoded
masks
=
output
[
is_within_crop_area
]
output
=
datapoints
.
Mask
.
wrap_like
(
output
,
masks
)
# We "mark" the invalid boxes as degenerate, and they can be
# removed by a later call to SanitizeBoundingBoxes()
output
[
~
params
[
"is_within_crop_area"
]]
=
0
return
output
...
...
torchvision/transforms/v2/_misc.py
View file @
b7892d3a
...
...
@@ -265,14 +265,14 @@ class SanitizeBoundingBoxes(Transform):
),
)
ws
,
hs
=
boxes
[:,
2
]
-
boxes
[:,
0
],
boxes
[:,
3
]
-
boxes
[:,
1
]
mask
=
(
ws
>=
self
.
min_size
)
&
(
hs
>=
self
.
min_size
)
&
(
boxes
>=
0
).
all
(
dim
=-
1
)
valid
=
(
ws
>=
self
.
min_size
)
&
(
hs
>=
self
.
min_size
)
&
(
boxes
>=
0
).
all
(
dim
=-
1
)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
image_h
,
image_w
=
boxes
.
spatial_size
mask
&=
(
boxes
[:,
0
]
<=
image_w
)
&
(
boxes
[:,
2
]
<=
image_w
)
mask
&=
(
boxes
[:,
1
]
<=
image_h
)
&
(
boxes
[:,
3
]
<=
image_h
)
valid
&=
(
boxes
[:,
0
]
<=
image_w
)
&
(
boxes
[:,
2
]
<=
image_w
)
valid
&=
(
boxes
[:,
1
]
<=
image_h
)
&
(
boxes
[:,
3
]
<=
image_h
)
params
=
dict
(
mask
=
mask
,
labels
=
labels
)
params
=
dict
(
valid
=
valid
,
labels
=
labels
)
flat_outputs
=
[
# Even though it may look like we're transforming all inputs, we don't:
# _transform() will only care about BoundingBoxes and the labels
...
...
@@ -284,7 +284,9 @@ class SanitizeBoundingBoxes(Transform):
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
(
inpt
is
not
None
and
inpt
is
params
[
"labels"
])
or
isinstance
(
inpt
,
datapoints
.
BoundingBox
):
inpt
=
inpt
[
params
[
"mask"
]]
if
(
inpt
is
not
None
and
inpt
is
params
[
"labels"
])
or
isinstance
(
inpt
,
(
datapoints
.
BoundingBox
,
datapoints
.
Mask
)
):
inpt
=
inpt
[
params
[
"valid"
]]
return
inpt
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment