OpenDAS / vision / Commits / 1f4a9846
"vscode:/vscode.git/clone" did not exist on "53f33def362e9f4596d165dc109e153034d4f574"
Unverified commit 1f4a9846, authored Feb 14, 2023 by Philip Meier, committed by GitHub on Feb 14, 2023

add proper smoke test for prototype transforms (#7238)

Parent: 0bdd01a7
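This commit replaces the ad-hoc parametrize_from_transforms helper with a table-driven smoke test that routes a mixed bag of inputs through every prototype transform. To run just this test from a torchvision checkout, something like the following should work (assuming pytest and the prototype test dependencies are installed; the -k filter name is taken from the test below):

    import pytest

    # Run only the smoke tests in this file, verbosely.
    pytest.main(["test/test_prototype_transforms.py", "-k", "test_common", "-v"])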
Showing 1 changed file with 192 additions and 42 deletions:

    test/test_prototype_transforms.py  (+192, -42)
test/test_prototype_transforms.py  (view file @ 1f4a9846)

 import itertools
+import pathlib
 import re
 import warnings
 from collections import defaultdict
...
@@ -20,15 +21,16 @@ from prototype_common_utils import (
     make_image,
     make_images,
     make_label,
     make_masks,
     make_one_hot_labels,
     make_segmentation_mask,
     make_video,
     make_videos,
 )
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import datapoints, transforms
-from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
+from torchvision.prototype.transforms import functional as F
+from torchvision.prototype.transforms.utils import check_type, is_simple_tensor, query_chw
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
...
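BATCH_EXTRA_DIMS keeps only the truthy entries of DEFAULT_EXTRA_DIMS, i.e. the shapes that add leading batch dimensions. A minimal sketch, using a hypothetical stand-in value since the real DEFAULT_EXTRA_DIMS lives in prototype_common_utils:

    DEFAULT_EXTRA_DIMS = [(), (4,), (2, 3)]  # hypothetical stand-in
    BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
    assert BATCH_EXTRA_DIMS == [(4,), (2, 3)]  # the empty tuple is filtered out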
@@ -66,53 +68,201 @@ def parametrize(transforms_with_inputs):
     )

-def parametrize_from_transforms(*transforms):
-    transforms_with_inputs = []
-    for transform in transforms:
-        for creation_fn in [
-            make_images,
-            make_bounding_boxes,
-            make_one_hot_labels,
-            make_vanilla_tensor_images,
-            make_pil_images,
-            make_masks,
-            make_videos,
-        ]:
-            inputs = list(creation_fn())
-            try:
-                output = transform(inputs[0])
-            except Exception:
-                continue
-            else:
-                if output is inputs[0]:
-                    continue
-
-            transforms_with_inputs.append((transform, inputs))
-
-    return parametrize(transforms_with_inputs)
+def auto_augment_adapter(transform, input, device):
+    adapted_input = {}
+    image_or_video_found = False
+    for key, value in input.items():
+        if isinstance(value, (datapoints.BoundingBox, datapoints.Mask)):
+            # AA transforms don't support bounding boxes or masks
+            continue
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):
+            if image_or_video_found:
+                # AA transforms only support a single image or video
+                continue
+            image_or_video_found = True
+        adapted_input[key] = value
+    return adapted_input
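The adapter leans on check_type accepting a mix of concrete types (datapoints.Image) and predicate functions (is_simple_tensor). A rough sketch of those semantics, inferred from the call sites rather than copied from the real implementation:

    def check_type_sketch(obj, types_or_checks):
        # Each entry is either a type to isinstance-check or a callable predicate.
        for type_or_check in types_or_checks:
            if isinstance(type_or_check, type):
                if isinstance(obj, type_or_check):
                    return True
            elif type_or_check(obj):
                return True
        return False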
+def linear_transformation_adapter(transform, input, device):
+    flat_inputs = list(input.values())
+    c, h, w = query_chw(
+        [
+            item
+            for item, needs_transform in zip(flat_inputs, transforms.Transform()._needs_transform_list(flat_inputs))
+            if needs_transform
+        ]
+    )
+    num_elements = c * h * w
+    transform.transformation_matrix = torch.randn((num_elements, num_elements), device=device)
+    transform.mean_vector = torch.randn((num_elements,), device=device)
+    return {key: value for key, value in input.items() if not isinstance(value, PIL.Image.Image)}
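linear_transformation_adapter exists because LinearTransformation operates on the flattened pixel vector: the matrix must be (C*H*W, C*H*W) and the mean vector (C*H*W,), and neither the spatial size nor the device is known before the test input is built. A shape-only sketch of that contract in plain torch, with hypothetical sizes standing in for what query_chw would return:

    import torch

    c, h, w = 3, 16, 16  # hypothetical query_chw result
    num_elements = c * h * w
    transformation_matrix = torch.randn(num_elements, num_elements)
    mean_vector = torch.randn(num_elements)

    x = torch.randn(c, h, w)
    # Conceptually: flatten, center, then apply the matrix.
    result = (x.reshape(1, -1) - mean_vector) @ transformation_matrix
    assert result.shape == (1, num_elements)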
+def normalize_adapter(transform, input, device):
+    adapted_input = {}
+    for key, value in input.items():
+        if isinstance(value, PIL.Image.Image):
+            # normalize doesn't support PIL images
+            continue
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
+            # normalize doesn't support integer images
+            value = F.convert_dtype(value, torch.float32)
+        adapted_input[key] = value
+    return adapted_input
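normalize_adapter converts integer inputs to float because channel-wise normalization is only defined (and only supported) for floating-point data. What the conversion amounts to for a uint8 image, sketched in plain torch in place of F.convert_dtype:

    import torch

    img = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
    img_float = img.to(torch.float32) / 255.0  # convert_dtype scales uint8 into [0, 1]
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    normalized = (img_float - mean) / std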
 class TestSmoke:
-    @parametrize_from_transforms(
-        transforms.RandomErasing(p=1.0),
-        transforms.Resize([16, 16], antialias=True),
-        transforms.CenterCrop([16, 16]),
-        transforms.ConvertDtype(),
-        transforms.RandomHorizontalFlip(),
-        transforms.Pad(5),
-        transforms.RandomZoomOut(),
-        transforms.RandomRotation(degrees=(-45, 45)),
-        transforms.RandomAffine(degrees=(-45, 45)),
-        transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True),
-        # TODO: Something wrong with input data setup. Let's fix that
-        # transforms.RandomEqualize(),
-        # transforms.RandomInvert(),
-        # transforms.RandomPosterize(bits=4),
-        # transforms.RandomSolarize(threshold=0.5),
-        # transforms.RandomAdjustSharpness(sharpness_factor=0.5),
-    )
-    def test_common(self, transform, input):
-        transform(input)
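In its place, the test is parametrized with explicit (transform, adapter) pairs, where the adapter, if any, patches the input dict before the transform runs. The ids lambda names each case after the transform class, so failures read as test_common[RandomErasing-...] rather than an opaque object repr; a quick illustration of what it returns:

    # ids=lambda transform: type(transform).__name__ yields the class name.
    assert type(transforms.RandomErasing(p=1.0)).__name__ == "RandomErasing"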
+    @pytest.mark.parametrize(
+        ("transform", "adapter"),
+        [
+            (transforms.RandomErasing(p=1.0), None),
+            (transforms.AugMix(), auto_augment_adapter),
+            (transforms.AutoAugment(), auto_augment_adapter),
+            (transforms.RandAugment(), auto_augment_adapter),
+            (transforms.TrivialAugmentWide(), auto_augment_adapter),
+            (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
+            (transforms.Grayscale(), None),
+            (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
+            (transforms.RandomAutocontrast(p=1.0), None),
+            (transforms.RandomEqualize(p=1.0), None),
+            (transforms.RandomGrayscale(p=1.0), None),
+            (transforms.RandomInvert(p=1.0), None),
+            (transforms.RandomPhotometricDistort(p=1.0), None),
+            (transforms.RandomPosterize(bits=4, p=1.0), None),
+            (transforms.RandomSolarize(threshold=0.5, p=1.0), None),
+            (transforms.CenterCrop([16, 16]), None),
+            (transforms.ElasticTransform(sigma=1.0), None),
+            (transforms.Pad(4), None),
+            (transforms.RandomAffine(degrees=30.0), None),
+            (transforms.RandomCrop([16, 16], pad_if_needed=True), None),
+            (transforms.RandomHorizontalFlip(p=1.0), None),
+            (transforms.RandomPerspective(p=1.0), None),
+            (transforms.RandomResize(min_size=10, max_size=20), None),
+            (transforms.RandomResizedCrop([16, 16]), None),
+            (transforms.RandomRotation(degrees=30), None),
+            (transforms.RandomShortestSize(min_size=10), None),
+            (transforms.RandomVerticalFlip(p=1.0), None),
+            (transforms.RandomZoomOut(p=1.0), None),
+            (transforms.Resize([16, 16], antialias=True), None),
+            (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2)), None),
+            (transforms.ClampBoundingBoxes(), None),
+            (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
+            (transforms.ConvertDtype(), None),
+            (transforms.GaussianBlur(kernel_size=3), None),
+            (
+                transforms.LinearTransformation(
+                    # These are just dummy values that will be filled by the adapter. We can't define them
+                    # upfront, because we know neither the spatial size nor the device at this point.
+                    transformation_matrix=torch.empty((1, 1)),
+                    mean_vector=torch.empty((1,)),
+                ),
+                linear_transformation_adapter,
+            ),
+            (transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), normalize_adapter),
+            (transforms.ToDtype(torch.float64), None),
+            (transforms.UniformTemporalSubsample(num_samples=2), None),
+        ],
+        ids=lambda transform: type(transform).__name__,
+    )
+    @pytest.mark.parametrize("container_type", [dict, list, tuple])
+    @pytest.mark.parametrize(
+        "image_or_video",
+        [
+            make_image(),
+            make_video(),
+            next(make_pil_images(color_spaces=["RGB"])),
+            next(make_vanilla_tensor_images()),
+        ],
+    )
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_common(self, transform, adapter, container_type, image_or_video, device):
+        spatial_size = F.get_spatial_size(image_or_video)
+        input = dict(
+            image_or_video=image_or_video,
+            image_datapoint=make_image(size=spatial_size),
+            video_datapoint=make_video(size=spatial_size),
+            image_pil=next(make_pil_images(sizes=[spatial_size], color_spaces=["RGB"])),
+            bounding_box_xyxy=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(3,)
+            ),
+            bounding_box_xywh=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.XYWH, spatial_size=spatial_size, extra_dims=(4,)
+            ),
+            bounding_box_cxcywh=make_bounding_box(
+                format=datapoints.BoundingBoxFormat.CXCYWH, spatial_size=spatial_size, extra_dims=(5,)
+            ),
+            bounding_box_degenerate_xyxy=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [2, 0, 1, 1],  # x1 > x2, y1 < y2
+                    [0, 2, 1, 1],  # x1 < x2, y1 > y2
+                    [2, 2, 1, 1],  # x1 > x2, y1 > y2
+                ],
+                format=datapoints.BoundingBoxFormat.XYXY,
+                spatial_size=spatial_size,
+            ),
+            bounding_box_degenerate_xywh=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [0, 0, 1, -1],  # negative height
+                    [0, 0, -1, 1],  # negative width
+                    [0, 0, -1, -1],  # negative height and width
+                ],
+                format=datapoints.BoundingBoxFormat.XYWH,
+                spatial_size=spatial_size,
+            ),
+            bounding_box_degenerate_cxcywh=datapoints.BoundingBox(
+                [
+                    [0, 0, 0, 0],  # no height or width
+                    [0, 0, 0, 1],  # no height
+                    [0, 0, 1, 0],  # no width
+                    [0, 0, 1, -1],  # negative height
+                    [0, 0, -1, 1],  # negative width
+                    [0, 0, -1, -1],  # negative height and width
+                ],
+                format=datapoints.BoundingBoxFormat.CXCYWH,
+                spatial_size=spatial_size,
+            ),
+            detection_mask=make_detection_mask(size=spatial_size),
+            segmentation_mask=make_segmentation_mask(size=spatial_size),
+            int=0,
+            float=0.0,
+            bool=True,
+            none=None,
+            str="str",
+            path=pathlib.Path.cwd(),
+            object=object(),
+            tensor=torch.empty(5),
+            array=np.empty(5),
+        )
+        if adapter is not None:
+            input = adapter(transform, input, device)
+
+        if container_type in {tuple, list}:
+            input = container_type(input.values())
+
+        input_flat, input_spec = tree_flatten(input)
+        input_flat = [item.to(device) if isinstance(item, torch.Tensor) else item for item in input_flat]
+        input = tree_unflatten(input_flat, input_spec)
+
+        torch.manual_seed(0)
+        output = transform(input)
+        output_flat, output_spec = tree_flatten(output)
+
+        assert output_spec == input_spec
+
+        for output_item, input_item, should_be_transformed in zip(
+            output_flat, input_flat, transforms.Transform()._needs_transform_list(input_flat)
+        ):
+            if should_be_transformed:
+                assert type(output_item) is type(input_item)
+            else:
+                assert output_item is input_item
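The spec-equality assertion works because tree_flatten reduces any nested container to a flat list of leaves plus a spec describing the nesting, and tree_unflatten inverts it; comparing specs therefore verifies that the transform preserved the container structure exactly. A self-contained illustration:

    from torch.utils._pytree import tree_flatten, tree_unflatten

    nested = {"a": 1, "b": [2, 3]}
    leaves, spec = tree_flatten(nested)
    assert leaves == [1, 2, 3]
    # Mapping over the leaves and unflattening restores the original structure.
    assert tree_unflatten([x * 10 for x in leaves], spec) == {"a": 10, "b": [20, 30]}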
+
+    @parametrize(
+        [
...
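One detail of the new input dict worth spelling out: the degenerate bounding boxes are deliberate, since transforms must pass boxes with zero or negative extent through without crashing. For the XYXY rows, validity requires x2 > x1 and y2 > y1, which every listed row fails; a quick plain-torch check:

    import torch

    boxes = torch.tensor(
        [[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [2, 0, 1, 1], [0, 2, 1, 1], [2, 2, 1, 1]]
    )
    valid = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
    assert not valid.any()  # all six rows are degenerate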