Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
1f4a9846
Unverified
Commit
1f4a9846
authored
Feb 14, 2023
by
Philip Meier
Committed by
GitHub
Feb 14, 2023
Browse files
add proper smoke test for prototype transforms (#7238)
parent
0bdd01a7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
192 additions
and
42 deletions
+192
-42
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+192
-42
No files found.
test/test_prototype_transforms.py
View file @
1f4a9846
import
itertools
import
pathlib
import
re
import
warnings
from
collections
import
defaultdict
...
...
@@ -20,15 +21,16 @@ from prototype_common_utils import (
make_image
,
make_images
,
make_label
,
make_masks
,
make_one_hot_labels
,
make_segmentation_mask
,
make_video
,
make_videos
,
)
from
torch.utils._pytree
import
tree_flatten
,
tree_unflatten
from
torchvision.ops.boxes
import
box_iou
from
torchvision.prototype
import
datapoints
,
transforms
from
torchvision.prototype.transforms.utils
import
check_type
,
is_simple_tensor
from
torchvision.prototype.transforms
import
functional
as
F
from
torchvision.prototype.transforms.utils
import
check_type
,
is_simple_tensor
,
query_chw
from
torchvision.transforms.functional
import
InterpolationMode
,
pil_to_tensor
,
to_pil_image
BATCH_EXTRA_DIMS
=
[
extra_dims
for
extra_dims
in
DEFAULT_EXTRA_DIMS
if
extra_dims
]
...
...
@@ -66,53 +68,201 @@ def parametrize(transforms_with_inputs):
)
def
parametrize_from_transforms
(
*
transforms
):
transforms_with_inputs
=
[]
for
transform
in
transforms
:
for
creation_fn
in
[
make_images
,
make_bounding_boxes
,
make_one_hot_labels
,
make_vanilla_tensor_images
,
make_pil_images
,
make_masks
,
make_videos
,
]:
inputs
=
list
(
creation_fn
())
try
:
output
=
transform
(
inputs
[
0
])
except
Exception
:
def
auto_augment_adapter
(
transform
,
input
,
device
):
adapted_input
=
{}
image_or_video_found
=
False
for
key
,
value
in
input
.
items
():
if
isinstance
(
value
,
(
datapoints
.
BoundingBox
,
datapoints
.
Mask
)):
# AA transforms don't support bounding boxes or masks
continue
elif
check_type
(
value
,
(
datapoints
.
Image
,
datapoints
.
Video
,
is_simple_tensor
,
PIL
.
Image
.
Image
)):
if
image_or_video_found
:
# AA transforms only support a single image or video
continue
else
:
if
output
is
inputs
[
0
]:
continue
image_or_video_found
=
True
adapted_input
[
key
]
=
value
return
adapted_input
def
linear_transformation_adapter
(
transform
,
input
,
device
):
flat_inputs
=
list
(
input
.
values
())
c
,
h
,
w
=
query_chw
(
[
item
for
item
,
needs_transform
in
zip
(
flat_inputs
,
transforms
.
Transform
().
_needs_transform_list
(
flat_inputs
))
if
needs_transform
]
)
num_elements
=
c
*
h
*
w
transform
.
transformation_matrix
=
torch
.
randn
((
num_elements
,
num_elements
),
device
=
device
)
transform
.
mean_vector
=
torch
.
randn
((
num_elements
,),
device
=
device
)
return
{
key
:
value
for
key
,
value
in
input
.
items
()
if
not
isinstance
(
value
,
PIL
.
Image
.
Image
)}
transforms_with_inputs
.
append
((
transform
,
inputs
))
return
parametrize
(
transforms_with_inputs
)
def
normalize_adapter
(
transform
,
input
,
device
):
adapted_input
=
{}
for
key
,
value
in
input
.
items
():
if
isinstance
(
value
,
PIL
.
Image
.
Image
):
# normalize doesn't support PIL images
continue
elif
check_type
(
value
,
(
datapoints
.
Image
,
datapoints
.
Video
,
is_simple_tensor
)):
# normalize doesn't support integer images
value
=
F
.
convert_dtype
(
value
,
torch
.
float32
)
adapted_input
[
key
]
=
value
return
adapted_input
class
TestSmoke
:
@
parametrize_from_transforms
(
transforms
.
RandomErasing
(
p
=
1.0
),
transforms
.
Resize
([
16
,
16
],
antialias
=
True
),
transforms
.
CenterCrop
([
16
,
16
]),
transforms
.
ConvertDtype
(),
transforms
.
RandomHorizontalFlip
(),
transforms
.
Pad
(
5
),
transforms
.
RandomZoomOut
(),
transforms
.
RandomRotation
(
degrees
=
(
-
45
,
45
)),
transforms
.
RandomAffine
(
degrees
=
(
-
45
,
45
)),
transforms
.
RandomCrop
([
16
,
16
],
padding
=
1
,
pad_if_needed
=
True
),
# TODO: Something wrong with input data setup. Let's fix that
# transforms.RandomEqualize(),
# transforms.RandomInvert(),
# transforms.RandomPosterize(bits=4),
# transforms.RandomSolarize(threshold=0.5),
# transforms.RandomAdjustSharpness(sharpness_factor=0.5),
@
pytest
.
mark
.
parametrize
(
(
"transform"
,
"adapter"
),
[
(
transforms
.
RandomErasing
(
p
=
1.0
),
None
),
(
transforms
.
AugMix
(),
auto_augment_adapter
),
(
transforms
.
AutoAugment
(),
auto_augment_adapter
),
(
transforms
.
RandAugment
(),
auto_augment_adapter
),
(
transforms
.
TrivialAugmentWide
(),
auto_augment_adapter
),
(
transforms
.
ColorJitter
(
brightness
=
0.1
,
contrast
=
0.2
,
saturation
=
0.3
,
hue
=
0.15
),
None
),
(
transforms
.
Grayscale
(),
None
),
(
transforms
.
RandomAdjustSharpness
(
sharpness_factor
=
0.5
,
p
=
1.0
),
None
),
(
transforms
.
RandomAutocontrast
(
p
=
1.0
),
None
),
(
transforms
.
RandomEqualize
(
p
=
1.0
),
None
),
(
transforms
.
RandomGrayscale
(
p
=
1.0
),
None
),
(
transforms
.
RandomInvert
(
p
=
1.0
),
None
),
(
transforms
.
RandomPhotometricDistort
(
p
=
1.0
),
None
),
(
transforms
.
RandomPosterize
(
bits
=
4
,
p
=
1.0
),
None
),
(
transforms
.
RandomSolarize
(
threshold
=
0.5
,
p
=
1.0
),
None
),
(
transforms
.
CenterCrop
([
16
,
16
]),
None
),
(
transforms
.
ElasticTransform
(
sigma
=
1.0
),
None
),
(
transforms
.
Pad
(
4
),
None
),
(
transforms
.
RandomAffine
(
degrees
=
30.0
),
None
),
(
transforms
.
RandomCrop
([
16
,
16
],
pad_if_needed
=
True
),
None
),
(
transforms
.
RandomHorizontalFlip
(
p
=
1.0
),
None
),
(
transforms
.
RandomPerspective
(
p
=
1.0
),
None
),
(
transforms
.
RandomResize
(
min_size
=
10
,
max_size
=
20
),
None
),
(
transforms
.
RandomResizedCrop
([
16
,
16
]),
None
),
(
transforms
.
RandomRotation
(
degrees
=
30
),
None
),
(
transforms
.
RandomShortestSize
(
min_size
=
10
),
None
),
(
transforms
.
RandomVerticalFlip
(
p
=
1.0
),
None
),
(
transforms
.
RandomZoomOut
(
p
=
1.0
),
None
),
(
transforms
.
Resize
([
16
,
16
],
antialias
=
True
),
None
),
(
transforms
.
ScaleJitter
((
16
,
16
),
scale_range
=
(
0.8
,
1.2
)),
None
),
(
transforms
.
ClampBoundingBoxes
(),
None
),
(
transforms
.
ConvertBoundingBoxFormat
(
datapoints
.
BoundingBoxFormat
.
CXCYWH
),
None
),
(
transforms
.
ConvertDtype
(),
None
),
(
transforms
.
GaussianBlur
(
kernel_size
=
3
),
None
),
(
transforms
.
LinearTransformation
(
# These are just dummy values that will be filled by the adapter. We can't define them upfront,
# because for we neither know the spatial size nor the device at this point
transformation_matrix
=
torch
.
empty
((
1
,
1
)),
mean_vector
=
torch
.
empty
((
1
,)),
),
linear_transformation_adapter
,
),
(
transforms
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]),
normalize_adapter
),
(
transforms
.
ToDtype
(
torch
.
float64
),
None
),
(
transforms
.
UniformTemporalSubsample
(
num_samples
=
2
),
None
),
],
ids
=
lambda
transform
:
type
(
transform
).
__name__
,
)
def
test_common
(
self
,
transform
,
input
):
transform
(
input
)
@
pytest
.
mark
.
parametrize
(
"container_type"
,
[
dict
,
list
,
tuple
])
@
pytest
.
mark
.
parametrize
(
"image_or_video"
,
[
make_image
(),
make_video
(),
next
(
make_pil_images
(
color_spaces
=
[
"RGB"
])),
next
(
make_vanilla_tensor_images
()),
],
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
def
test_common
(
self
,
transform
,
adapter
,
container_type
,
image_or_video
,
device
):
spatial_size
=
F
.
get_spatial_size
(
image_or_video
)
input
=
dict
(
image_or_video
=
image_or_video
,
image_datapoint
=
make_image
(
size
=
spatial_size
),
video_datapoint
=
make_video
(
size
=
spatial_size
),
image_pil
=
next
(
make_pil_images
(
sizes
=
[
spatial_size
],
color_spaces
=
[
"RGB"
])),
bounding_box_xyxy
=
make_bounding_box
(
format
=
datapoints
.
BoundingBoxFormat
.
XYXY
,
spatial_size
=
spatial_size
,
extra_dims
=
(
3
,)
),
bounding_box_xywh
=
make_bounding_box
(
format
=
datapoints
.
BoundingBoxFormat
.
XYWH
,
spatial_size
=
spatial_size
,
extra_dims
=
(
4
,)
),
bounding_box_cxcywh
=
make_bounding_box
(
format
=
datapoints
.
BoundingBoxFormat
.
CXCYWH
,
spatial_size
=
spatial_size
,
extra_dims
=
(
5
,)
),
bounding_box_degenerate_xyxy
=
datapoints
.
BoundingBox
(
[
[
0
,
0
,
0
,
0
],
# no height or width
[
0
,
0
,
0
,
1
],
# no height
[
0
,
0
,
1
,
0
],
# no width
[
2
,
0
,
1
,
1
],
# x1 > x2, y1 < y2
[
0
,
2
,
1
,
1
],
# x1 < x2, y1 > y2
[
2
,
2
,
1
,
1
],
# x1 > x2, y1 > y2
],
format
=
datapoints
.
BoundingBoxFormat
.
XYXY
,
spatial_size
=
spatial_size
,
),
bounding_box_degenerate_xywh
=
datapoints
.
BoundingBox
(
[
[
0
,
0
,
0
,
0
],
# no height or width
[
0
,
0
,
0
,
1
],
# no height
[
0
,
0
,
1
,
0
],
# no width
[
0
,
0
,
1
,
-
1
],
# negative height
[
0
,
0
,
-
1
,
1
],
# negative width
[
0
,
0
,
-
1
,
-
1
],
# negative height and width
],
format
=
datapoints
.
BoundingBoxFormat
.
XYWH
,
spatial_size
=
spatial_size
,
),
bounding_box_degenerate_cxcywh
=
datapoints
.
BoundingBox
(
[
[
0
,
0
,
0
,
0
],
# no height or width
[
0
,
0
,
0
,
1
],
# no height
[
0
,
0
,
1
,
0
],
# no width
[
0
,
0
,
1
,
-
1
],
# negative height
[
0
,
0
,
-
1
,
1
],
# negative width
[
0
,
0
,
-
1
,
-
1
],
# negative height and width
],
format
=
datapoints
.
BoundingBoxFormat
.
CXCYWH
,
spatial_size
=
spatial_size
,
),
detection_mask
=
make_detection_mask
(
size
=
spatial_size
),
segmentation_mask
=
make_segmentation_mask
(
size
=
spatial_size
),
int
=
0
,
float
=
0.0
,
bool
=
True
,
none
=
None
,
str
=
"str"
,
path
=
pathlib
.
Path
.
cwd
(),
object
=
object
(),
tensor
=
torch
.
empty
(
5
),
array
=
np
.
empty
(
5
),
)
if
adapter
is
not
None
:
input
=
adapter
(
transform
,
input
,
device
)
if
container_type
in
{
tuple
,
list
}:
input
=
container_type
(
input
.
values
())
input_flat
,
input_spec
=
tree_flatten
(
input
)
input_flat
=
[
item
.
to
(
device
)
if
isinstance
(
item
,
torch
.
Tensor
)
else
item
for
item
in
input_flat
]
input
=
tree_unflatten
(
input_flat
,
input_spec
)
torch
.
manual_seed
(
0
)
output
=
transform
(
input
)
output_flat
,
output_spec
=
tree_flatten
(
output
)
assert
output_spec
==
input_spec
for
output_item
,
input_item
,
should_be_transformed
in
zip
(
output_flat
,
input_flat
,
transforms
.
Transform
().
_needs_transform_list
(
input_flat
)
):
if
should_be_transformed
:
assert
type
(
output_item
)
is
type
(
input_item
)
else
:
assert
output_item
is
input_item
@
parametrize
(
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment