Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b7892d3a
Unverified
Commit
b7892d3a
authored
Feb 16, 2023
by
Nicolas Hug
Committed by
GitHub
Feb 16, 2023
Browse files
Make RandomIoUCrop compatible with SanitizeBoundingBoxes (#7268)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
d4d20f01
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
47 additions
and
42 deletions
+47
-42
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+26
-17
test/test_prototype_transforms_consistency.py
test/test_prototype_transforms_consistency.py
+10
-7
torchvision/transforms/v2/_geometry.py
torchvision/transforms/v2/_geometry.py
+3
-12
torchvision/transforms/v2/_misc.py
torchvision/transforms/v2/_misc.py
+8
-6
No files found.
test/test_prototype_transforms.py
View file @
b7892d3a
...
@@ -1488,16 +1488,13 @@ class TestRandomIoUCrop:
...
@@ -1488,16 +1488,13 @@ class TestRandomIoUCrop:
fn
.
assert_has_calls
(
expected_calls
)
fn
.
assert_has_calls
(
expected_calls
)
expected_within_targets
=
sum
(
is_within_crop_area
)
# check number of bboxes vs number of labels:
# check number of bboxes vs number of labels:
output_bboxes
=
output
[
1
]
output_bboxes
=
output
[
1
]
assert
isinstance
(
output_bboxes
,
datapoints
.
BoundingBox
)
assert
isinstance
(
output_bboxes
,
datapoints
.
BoundingBox
)
assert
len
(
output_bboxes
)
==
expected_within_targets
assert
(
output_bboxes
[
~
is_within_crop_area
]
==
0
).
all
()
output_masks
=
output
[
2
]
output_masks
=
output
[
2
]
assert
isinstance
(
output_masks
,
datapoints
.
Mask
)
assert
isinstance
(
output_masks
,
datapoints
.
Mask
)
assert
len
(
output_masks
)
==
expected_within_targets
class
TestScaleJitter
:
class
TestScaleJitter
:
...
@@ -2253,10 +2250,11 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
...
@@ -2253,10 +2250,11 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
@
pytest
.
mark
.
parametrize
(
"image_type"
,
(
PIL
.
Image
,
torch
.
Tensor
,
datapoints
.
Image
))
@
pytest
.
mark
.
parametrize
(
"image_type"
,
(
PIL
.
Image
,
torch
.
Tensor
,
datapoints
.
Image
))
@
pytest
.
mark
.
parametrize
(
"label_type"
,
(
torch
.
Tensor
,
list
))
@
pytest
.
mark
.
parametrize
(
"data_augmentation"
,
(
"hflip"
,
"lsj"
,
"multiscale"
,
"ssd"
,
"ssdlite"
))
@
pytest
.
mark
.
parametrize
(
"data_augmentation"
,
(
"hflip"
,
"lsj"
,
"multiscale"
,
"ssd"
,
"ssdlite"
))
@
pytest
.
mark
.
parametrize
(
"to_tensor"
,
(
transforms
.
ToTensor
,
transforms
.
ToImageTensor
))
@
pytest
.
mark
.
parametrize
(
"to_tensor"
,
(
transforms
.
ToTensor
,
transforms
.
ToImageTensor
))
def
test_detection_preset
(
image_type
,
label_type
,
data_augmentation
,
to_tensor
):
@
pytest
.
mark
.
parametrize
(
"sanitize"
,
(
True
,
False
))
def
test_detection_preset
(
image_type
,
data_augmentation
,
to_tensor
,
sanitize
):
torch
.
manual_seed
(
0
)
if
data_augmentation
==
"hflip"
:
if
data_augmentation
==
"hflip"
:
t
=
[
t
=
[
transforms
.
RandomHorizontalFlip
(
p
=
1
),
transforms
.
RandomHorizontalFlip
(
p
=
1
),
...
@@ -2290,20 +2288,20 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
...
@@ -2290,20 +2288,20 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
t
=
[
t
=
[
transforms
.
RandomPhotometricDistort
(
p
=
1
),
transforms
.
RandomPhotometricDistort
(
p
=
1
),
transforms
.
RandomZoomOut
(
fill
=
defaultdict
(
lambda
:
(
123.0
,
117.0
,
104.0
),
{
datapoints
.
Mask
:
0
})),
transforms
.
RandomZoomOut
(
fill
=
defaultdict
(
lambda
:
(
123.0
,
117.0
,
104.0
),
{
datapoints
.
Mask
:
0
})),
# TODO: put back IoUCrop once we remove its hard requirement for Labels
transforms
.
RandomIoUCrop
(),
# transforms.RandomIoUCrop(),
transforms
.
RandomHorizontalFlip
(
p
=
1
),
transforms
.
RandomHorizontalFlip
(
p
=
1
),
to_tensor
(),
to_tensor
(),
transforms
.
ConvertImageDtype
(
torch
.
float
),
transforms
.
ConvertImageDtype
(
torch
.
float
),
]
]
elif
data_augmentation
==
"ssdlite"
:
elif
data_augmentation
==
"ssdlite"
:
t
=
[
t
=
[
# TODO: put back IoUCrop once we remove its hard requirement for Labels
transforms
.
RandomIoUCrop
(),
# transforms.RandomIoUCrop(),
transforms
.
RandomHorizontalFlip
(
p
=
1
),
transforms
.
RandomHorizontalFlip
(
p
=
1
),
to_tensor
(),
to_tensor
(),
transforms
.
ConvertImageDtype
(
torch
.
float
),
transforms
.
ConvertImageDtype
(
torch
.
float
),
]
]
if
sanitize
:
t
+=
[
transforms
.
SanitizeBoundingBoxes
()]
t
=
transforms
.
Compose
(
t
)
t
=
transforms
.
Compose
(
t
)
num_boxes
=
5
num_boxes
=
5
...
@@ -2317,10 +2315,7 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
...
@@ -2317,10 +2315,7 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
assert
is_simple_tensor
(
image
)
assert
is_simple_tensor
(
image
)
label
=
torch
.
randint
(
0
,
10
,
size
=
(
num_boxes
,))
label
=
torch
.
randint
(
0
,
10
,
size
=
(
num_boxes
,))
if
label_type
is
list
:
label
=
label
.
tolist
()
# TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
boxes
=
torch
.
randint
(
0
,
min
(
H
,
W
)
//
2
,
size
=
(
num_boxes
,
4
))
boxes
=
torch
.
randint
(
0
,
min
(
H
,
W
)
//
2
,
size
=
(
num_boxes
,
4
))
boxes
[:,
2
:]
+=
boxes
[:,
:
2
]
boxes
[:,
2
:]
+=
boxes
[:,
:
2
]
boxes
=
boxes
.
clamp
(
min
=
0
,
max
=
min
(
H
,
W
))
boxes
=
boxes
.
clamp
(
min
=
0
,
max
=
min
(
H
,
W
))
...
@@ -2343,8 +2338,19 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
...
@@ -2343,8 +2338,19 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
assert
isinstance
(
out
[
"image"
],
datapoints
.
Image
)
assert
isinstance
(
out
[
"image"
],
datapoints
.
Image
)
assert
isinstance
(
out
[
"label"
],
type
(
sample
[
"label"
]))
assert
isinstance
(
out
[
"label"
],
type
(
sample
[
"label"
]))
out
[
"label"
]
=
torch
.
tensor
(
out
[
"label"
])
num_boxes_expected
=
{
assert
out
[
"boxes"
].
shape
[
0
]
==
out
[
"masks"
].
shape
[
0
]
==
out
[
"label"
].
shape
[
0
]
==
num_boxes
# ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It
# doesn't remove them strictly speaking, it just marks some boxes as
# degenerate and those boxes will be later removed by
# SanitizeBoundingBoxes(), which we add to the pipelines if the sanitize
# param is True.
# Note that the values below are probably specific to the random seed
# set above (which is fine).
(
True
,
"ssd"
):
4
,
(
True
,
"ssdlite"
):
4
,
}.
get
((
sanitize
,
data_augmentation
),
num_boxes
)
assert
out
[
"boxes"
].
shape
[
0
]
==
out
[
"masks"
].
shape
[
0
]
==
out
[
"label"
].
shape
[
0
]
==
num_boxes_expected
@
pytest
.
mark
.
parametrize
(
"min_size"
,
(
1
,
10
))
@
pytest
.
mark
.
parametrize
(
"min_size"
,
(
1
,
10
))
...
@@ -2377,7 +2383,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
...
@@ -2377,7 +2383,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
valid_indices
=
[
i
for
(
i
,
is_valid
)
in
enumerate
(
is_valid_mask
)
if
is_valid
]
valid_indices
=
[
i
for
(
i
,
is_valid
)
in
enumerate
(
is_valid_mask
)
if
is_valid
]
boxes
=
torch
.
tensor
(
boxes
)
boxes
=
torch
.
tensor
(
boxes
)
labels
=
torch
.
arange
(
boxes
.
shape
[
-
2
])
labels
=
torch
.
arange
(
boxes
.
shape
[
0
])
boxes
=
datapoints
.
BoundingBox
(
boxes
=
datapoints
.
BoundingBox
(
boxes
,
boxes
,
...
@@ -2385,12 +2391,15 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
...
@@ -2385,12 +2391,15 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
spatial_size
=
(
H
,
W
),
spatial_size
=
(
H
,
W
),
)
)
masks
=
datapoints
.
Mask
(
torch
.
randint
(
0
,
2
,
size
=
(
boxes
.
shape
[
0
],
H
,
W
)))
sample
=
{
sample
=
{
"image"
:
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
H
,
W
),
dtype
=
torch
.
uint8
),
"image"
:
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
H
,
W
),
dtype
=
torch
.
uint8
),
"labels"
:
labels
,
"labels"
:
labels
,
"boxes"
:
boxes
,
"boxes"
:
boxes
,
"whatever"
:
torch
.
rand
(
10
),
"whatever"
:
torch
.
rand
(
10
),
"None"
:
None
,
"None"
:
None
,
"masks"
:
masks
,
}
}
out
=
transforms
.
SanitizeBoundingBoxes
(
min_size
=
min_size
,
labels_getter
=
labels_getter
)(
sample
)
out
=
transforms
.
SanitizeBoundingBoxes
(
min_size
=
min_size
,
labels_getter
=
labels_getter
)(
sample
)
...
@@ -2402,7 +2411,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
...
@@ -2402,7 +2411,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
assert
out
[
"labels"
]
is
sample
[
"labels"
]
assert
out
[
"labels"
]
is
sample
[
"labels"
]
else
:
else
:
assert
isinstance
(
out
[
"labels"
],
torch
.
Tensor
)
assert
isinstance
(
out
[
"labels"
],
torch
.
Tensor
)
assert
out
[
"boxes"
].
shape
[
:
-
1
]
==
out
[
"labels"
].
shape
assert
out
[
"boxes"
].
shape
[
0
]
==
out
[
"labels"
].
shape
[
0
]
==
out
[
"masks"
].
shape
[
0
]
# This works because we conveniently set labels to arange(num_boxes)
# This works because we conveniently set labels to arange(num_boxes)
assert
out
[
"labels"
].
tolist
()
==
valid_indices
assert
out
[
"labels"
].
tolist
()
==
valid_indices
...
...
test/test_prototype_transforms_consistency.py
View file @
b7892d3a
...
@@ -1090,13 +1090,16 @@ class TestRefDetTransforms:
...
@@ -1090,13 +1090,16 @@ class TestRefDetTransforms:
"t_ref, t, data_kwargs"
,
"t_ref, t, data_kwargs"
,
[
[
(
det_transforms
.
RandomHorizontalFlip
(
p
=
1.0
),
v2_transforms
.
RandomHorizontalFlip
(
p
=
1.0
),
{}),
(
det_transforms
.
RandomHorizontalFlip
(
p
=
1.0
),
v2_transforms
.
RandomHorizontalFlip
(
p
=
1.0
),
{}),
# FIXME: make
(
# v2_transforms.Compose([
det_transforms
.
RandomIoUCrop
(),
# v2_transforms.RandomIoUCrop(),
v2_transforms
.
Compose
(
# v2_transforms.SanitizeBoundingBoxes()
[
# ])
v2_transforms
.
RandomIoUCrop
(),
# work
v2_transforms
.
SanitizeBoundingBoxes
(
labels_getter
=
lambda
sample
:
sample
[
1
][
"labels"
]),
# (det_transforms.RandomIoUCrop(), v2_transforms.RandomIoUCrop(), {"with_mask": False}),
]
),
{
"with_mask"
:
False
},
),
(
det_transforms
.
RandomZoomOut
(),
v2_transforms
.
RandomZoomOut
(),
{
"with_mask"
:
False
}),
(
det_transforms
.
RandomZoomOut
(),
v2_transforms
.
RandomZoomOut
(),
{
"with_mask"
:
False
}),
(
det_transforms
.
ScaleJitter
((
1024
,
1024
)),
v2_transforms
.
ScaleJitter
((
1024
,
1024
)),
{}),
(
det_transforms
.
ScaleJitter
((
1024
,
1024
)),
v2_transforms
.
ScaleJitter
((
1024
,
1024
)),
{}),
(
(
...
...
torchvision/transforms/v2/_geometry.py
View file @
b7892d3a
...
@@ -721,8 +721,6 @@ class RandomIoUCrop(Transform):
...
@@ -721,8 +721,6 @@ class RandomIoUCrop(Transform):
if
left
==
right
or
top
==
bottom
:
if
left
==
right
or
top
==
bottom
:
continue
continue
# FIXME: I think we can stop here?
# check for any valid boxes with centers within the crop area
# check for any valid boxes with centers within the crop area
xyxy_bboxes
=
F
.
convert_format_bounding_box
(
xyxy_bboxes
=
F
.
convert_format_bounding_box
(
bboxes
.
as_subclass
(
torch
.
Tensor
),
bboxes
.
format
,
datapoints
.
BoundingBoxFormat
.
XYXY
bboxes
.
as_subclass
(
torch
.
Tensor
),
bboxes
.
format
,
datapoints
.
BoundingBoxFormat
.
XYXY
...
@@ -745,23 +743,16 @@ class RandomIoUCrop(Transform):
...
@@ -745,23 +743,16 @@ class RandomIoUCrop(Transform):
return
dict
(
top
=
top
,
left
=
left
,
height
=
new_h
,
width
=
new_w
,
is_within_crop_area
=
is_within_crop_area
)
return
dict
(
top
=
top
,
left
=
left
,
height
=
new_h
,
width
=
new_w
,
is_within_crop_area
=
is_within_crop_area
)
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
# FIXME: refactor this to not remove anything
if
len
(
params
)
<
1
:
if
len
(
params
)
<
1
:
return
inpt
return
inpt
is_within_crop_area
=
params
[
"is_within_crop_area"
]
output
=
F
.
crop
(
inpt
,
top
=
params
[
"top"
],
left
=
params
[
"left"
],
height
=
params
[
"height"
],
width
=
params
[
"width"
])
output
=
F
.
crop
(
inpt
,
top
=
params
[
"top"
],
left
=
params
[
"left"
],
height
=
params
[
"height"
],
width
=
params
[
"width"
])
if
isinstance
(
output
,
datapoints
.
BoundingBox
):
if
isinstance
(
output
,
datapoints
.
BoundingBox
):
bboxes
=
output
[
is_within_crop_area
]
# We "mark" the invalid boxes as degenreate, and they can be
bboxes
=
F
.
clamp_bounding_box
(
bboxes
,
output
.
format
,
output
.
spatial_size
)
# removed by a later call to SanitizeBoundingBoxes()
output
=
datapoints
.
BoundingBox
.
wrap_like
(
output
,
bboxes
)
output
[
~
params
[
"is_within_crop_area"
]]
=
0
elif
isinstance
(
output
,
datapoints
.
Mask
):
# apply is_within_crop_area if mask is one-hot encoded
masks
=
output
[
is_within_crop_area
]
output
=
datapoints
.
Mask
.
wrap_like
(
output
,
masks
)
return
output
return
output
...
...
torchvision/transforms/v2/_misc.py
View file @
b7892d3a
...
@@ -265,14 +265,14 @@ class SanitizeBoundingBoxes(Transform):
...
@@ -265,14 +265,14 @@ class SanitizeBoundingBoxes(Transform):
),
),
)
)
ws
,
hs
=
boxes
[:,
2
]
-
boxes
[:,
0
],
boxes
[:,
3
]
-
boxes
[:,
1
]
ws
,
hs
=
boxes
[:,
2
]
-
boxes
[:,
0
],
boxes
[:,
3
]
-
boxes
[:,
1
]
mask
=
(
ws
>=
self
.
min_size
)
&
(
hs
>=
self
.
min_size
)
&
(
boxes
>=
0
).
all
(
dim
=-
1
)
valid
=
(
ws
>=
self
.
min_size
)
&
(
hs
>=
self
.
min_size
)
&
(
boxes
>=
0
).
all
(
dim
=-
1
)
# TODO: Do we really need to check for out of bounds here? All
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
# transforms should be clamping anyway, so this should never happen?
image_h
,
image_w
=
boxes
.
spatial_size
image_h
,
image_w
=
boxes
.
spatial_size
mask
&=
(
boxes
[:,
0
]
<=
image_w
)
&
(
boxes
[:,
2
]
<=
image_w
)
valid
&=
(
boxes
[:,
0
]
<=
image_w
)
&
(
boxes
[:,
2
]
<=
image_w
)
mask
&=
(
boxes
[:,
1
]
<=
image_h
)
&
(
boxes
[:,
3
]
<=
image_h
)
valid
&=
(
boxes
[:,
1
]
<=
image_h
)
&
(
boxes
[:,
3
]
<=
image_h
)
params
=
dict
(
mask
=
mask
,
labels
=
labels
)
params
=
dict
(
valid
=
valid
,
labels
=
labels
)
flat_outputs
=
[
flat_outputs
=
[
# Even-though it may look like we're transforming all inputs, we don't:
# Even-though it may look like we're transforming all inputs, we don't:
# _transform() will only care about BoundingBoxes and the labels
# _transform() will only care about BoundingBoxes and the labels
...
@@ -284,7 +284,9 @@ class SanitizeBoundingBoxes(Transform):
...
@@ -284,7 +284,9 @@ class SanitizeBoundingBoxes(Transform):
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
def
_transform
(
self
,
inpt
:
Any
,
params
:
Dict
[
str
,
Any
])
->
Any
:
if
(
inpt
is
not
None
and
inpt
is
params
[
"labels"
])
or
isinstance
(
inpt
,
datapoints
.
BoundingBox
):
if
(
inpt
is
not
None
and
inpt
is
params
[
"labels"
])
or
isinstance
(
inpt
=
inpt
[
params
[
"mask"
]]
inpt
,
(
datapoints
.
BoundingBox
,
datapoints
.
Mask
)
):
inpt
=
inpt
[
params
[
"valid"
]]
return
inpt
return
inpt
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment