Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
0aed8329
Unverified
Commit
0aed8329
authored
Feb 13, 2023
by
Nicolas Hug
Committed by
GitHub
Feb 13, 2023
Browse files
Add tests for transform presets, and various fixes (#7223)
parent
c73411a4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
179 additions
and
11 deletions
+179
-11
test/test_prototype_transforms.py
test/test_prototype_transforms.py
+152
-0
torchvision/prototype/transforms/_auto_augment.py
torchvision/prototype/transforms/_auto_augment.py
+3
-2
torchvision/prototype/transforms/_color.py
torchvision/prototype/transforms/_color.py
+2
-1
torchvision/prototype/transforms/_transform.py
torchvision/prototype/transforms/_transform.py
+22
-8
No files found.
test/test_prototype_transforms.py
View file @
0aed8329
import
itertools
import
re
from
collections
import
defaultdict
import
numpy
as
np
...
...
@@ -1988,3 +1989,154 @@ class TestUniformTemporalSubsample:
assert
type
(
output
)
is
type
(
inpt
)
assert
output
.
shape
[
-
4
]
==
num_samples
assert
output
.
dtype
==
inpt
.
dtype
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
    """Run a classification-style pipeline end to end on a single sample.

    The sample is an (image, label) pair packed either as a dict or as a
    tuple. After the pipeline runs, the test checks that the container type
    is unchanged, that the image has a 224x224 spatial size, and that the
    label compares equal to the one that was passed in.
    """
    raw_pixels = torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8)
    image = datapoints.Image(raw_pixels)
    if image_type is PIL.Image:
        # Drop the leading batch dimension before converting to PIL.
        image = to_pil_image(image[0])
    elif image_type is torch.Tensor:
        image = image.as_subclass(torch.Tensor)
        assert is_simple_tensor(image)

    if label_type is int:
        label = 1
    else:
        label = torch.tensor([1])

    sample = {"image": image, "label": label} if dataset_return_type is dict else (image, label)

    pipeline_steps = [
        transforms.RandomResizedCrop((224, 224)),
        transforms.RandomHorizontalFlip(p=1),
        transforms.RandAugment(),
        transforms.TrivialAugmentWide(),
        transforms.AugMix(),
        transforms.AutoAugment(),
        to_tensor(),
        # TODO: ConvertImageDtype is a pass-through on PIL images, is that
        # intended? If the conversion to tensor happened after it, the image
        # would still be uint8, which makes Normalize fail.
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        transforms.RandomErasing(p=1),
    ]
    t = transforms.Compose(pipeline_steps)

    out = t(sample)

    # The pipeline must hand back the same kind of container it received.
    assert type(out) == type(sample)

    if dataset_return_type is not tuple:
        assert out.keys() == sample.keys()
        out_image, out_label = out.values()
    else:
        out_image, out_label = out

    assert out_image.shape[-2:] == (224, 224)
    assert out_label == label
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, list))
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
    """Run a detection-style pipeline end to end on a dict sample.

    The sample holds an image, labels, bounding boxes and masks. After the
    pipeline runs, the test checks the type of the output image, that the
    label keeps its input type, and that boxes, masks and labels all still
    have ``num_boxes`` entries along their first dimension.
    """
    # Augmentation-specific transforms that run before the shared tail.
    if data_augmentation in ("hflip", "ssdlite"):
        # TODO (ssdlite): put back IoUCrop once we remove its hard requirement for Labels
        # transforms.RandomIoUCrop(),
        leading_steps = []
    elif data_augmentation == "lsj":
        leading_steps = [
            transforms.ScaleJitter(target_size=(1024, 1024), antialias=True),
            # Note: replaced FixedSizeCrop with RandomCrop, because we're
            # leaving FixedSizeCrop in prototype for now, and it expects Label
            # classes which we won't release yet.
            # transforms.FixedSizeCrop(
            #     size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
            # ),
            transforms.RandomCrop((1024, 1024), pad_if_needed=True),
        ]
    elif data_augmentation == "multiscale":
        leading_steps = [
            transforms.RandomShortestSize(
                min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800),
                max_size=1333,
                antialias=True,
            ),
        ]
    elif data_augmentation == "ssd":
        zoom_out_fill = defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
        leading_steps = [
            transforms.RandomPhotometricDistort(p=1),
            transforms.RandomZoomOut(fill=zoom_out_fill),
            # TODO: put back IoUCrop once we remove its hard requirement for Labels
            # transforms.RandomIoUCrop(),
        ]

    # Every preset finishes with the same three transforms.
    shared_tail = [
        transforms.RandomHorizontalFlip(p=1),
        to_tensor(),
        transforms.ConvertImageDtype(torch.float),
    ]
    t = transforms.Compose(leading_steps + shared_tail)

    num_boxes = 5
    height = width = 250
    shortest_side = min(height, width)

    image = datapoints.Image(torch.randint(0, 256, size=(1, 3, height, width), dtype=torch.uint8))
    if image_type is PIL.Image:
        # Drop the leading batch dimension before converting to PIL.
        image = to_pil_image(image[0])
    elif image_type is torch.Tensor:
        image = image.as_subclass(torch.Tensor)
        assert is_simple_tensor(image)

    label = torch.randint(0, 10, size=(num_boxes,))
    if label_type is list:
        label = label.tolist()

    # TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
    raw_boxes = torch.randint(0, shortest_side // 2, size=(num_boxes, 4))
    # Add the first two columns onto the last two, then clamp to the image extent.
    raw_boxes[:, 2:] += raw_boxes[:, :2]
    raw_boxes = raw_boxes.clamp(min=0, max=shortest_side)
    boxes = datapoints.BoundingBox(raw_boxes, format="XYXY", spatial_size=(height, width))

    masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, height, width), dtype=torch.uint8))

    sample = {"image": image, "label": label, "boxes": boxes, "masks": masks}

    out = t(sample)

    expect_simple_tensor = to_tensor is transforms.ToTensor and image_type is not datapoints.Image
    if expect_simple_tensor:
        assert is_simple_tensor(out["image"])
    else:
        assert isinstance(out["image"], datapoints.Image)
    assert isinstance(out["label"], type(sample["label"]))

    # Convert the label to a tensor so that it exposes ``.shape`` like the others.
    out["label"] = torch.tensor(out["label"])
    leading_dims = [out[key].shape[0] for key in ("boxes", "masks", "label")]
    assert all(dim == num_boxes for dim in leading_dims)
torchvision/prototype/transforms/_auto_augment.py
View file @
0aed8329
...
...
@@ -37,10 +37,11 @@ class _AutoAugmentBase(Transform):
unsupported_types
:
Tuple
[
Type
,
...]
=
(
datapoints
.
BoundingBox
,
datapoints
.
Mask
),
)
->
Tuple
[
Tuple
[
List
[
Any
],
TreeSpec
,
int
],
Union
[
datapoints
.
ImageType
,
datapoints
.
VideoType
]]:
flat_inputs
,
spec
=
tree_flatten
(
inputs
if
len
(
inputs
)
>
1
else
inputs
[
0
])
needs_transform_list
=
self
.
_needs_transform_list
(
flat_inputs
)
image_or_videos
=
[]
for
idx
,
inpt
in
enumerate
(
flat_inputs
):
if
check_type
(
for
idx
,
(
inpt
,
needs_transform
)
in
enumerate
(
zip
(
flat_inputs
,
needs_transform_list
)
):
if
needs_transform
and
check_type
(
inpt
,
(
datapoints
.
Image
,
...
...
torchvision/prototype/transforms/_color.py
View file @
0aed8329
...
...
@@ -169,7 +169,8 @@ class RandomPhotometricDistort(Transform):
if
isinstance
(
orig_inpt
,
PIL
.
Image
.
Image
):
inpt
=
F
.
pil_to_tensor
(
inpt
)
output
=
inpt
[...,
permutation
,
:,
:]
# TODO: Find a better fix than as_subclass???
output
=
inpt
[...,
permutation
,
:,
:].
as_subclass
(
type
(
inpt
))
if
isinstance
(
orig_inpt
,
PIL
.
Image
.
Image
):
output
=
F
.
to_image_pil
(
output
)
...
...
torchvision/prototype/transforms/_transform.py
View file @
0aed8329
...
...
@@ -36,8 +36,19 @@ class Transform(nn.Module):
self
.
_check_inputs
(
flat_inputs
)
params
=
self
.
_get_params
(
flat_inputs
)
needs_transform_list
=
self
.
_needs_transform_list
(
flat_inputs
)
params
=
self
.
_get_params
(
[
inpt
for
(
inpt
,
needs_transform
)
in
zip
(
flat_inputs
,
needs_transform_list
)
if
needs_transform
]
)
flat_outputs
=
[
self
.
_transform
(
inpt
,
params
)
if
needs_transform
else
inpt
for
(
inpt
,
needs_transform
)
in
zip
(
flat_inputs
,
needs_transform_list
)
]
return
tree_unflatten
(
flat_outputs
,
spec
)
def
_needs_transform_list
(
self
,
flat_inputs
:
List
[
Any
])
->
List
[
bool
]:
# Below is a heuristic on how to deal with simple tensor inputs:
# 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
# (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
...
...
@@ -53,7 +64,8 @@ class Transform(nn.Module):
# The heuristic should work well for most people in practice. The only case where it doesn't is if someone
# tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
# However, this case wasn't supported by transforms v1 either, so there is no BC concern.
flat_outputs
=
[]
needs_transform_list
=
[]
transform_simple_tensor
=
not
has_any
(
flat_inputs
,
datapoints
.
Image
,
datapoints
.
Video
,
PIL
.
Image
.
Image
)
for
inpt
in
flat_inputs
:
needs_transform
=
True
...
...
@@ -65,10 +77,8 @@ class Transform(nn.Module):
transform_simple_tensor
=
False
else
:
needs_transform
=
False
flat_outputs
.
append
(
self
.
_transform
(
inpt
,
params
)
if
needs_transform
else
inpt
)
return
tree_unflatten
(
flat_outputs
,
spec
)
needs_transform_list
.
append
(
needs_transform
)
return
needs_transform_list
def
extra_repr
(
self
)
->
str
:
extra
=
[]
...
...
@@ -159,10 +169,14 @@ class _RandomApplyTransform(Transform):
if
torch
.
rand
(
1
)
>=
self
.
p
:
return
inputs
params
=
self
.
_get_params
(
flat_inputs
)
needs_transform_list
=
self
.
_needs_transform_list
(
flat_inputs
)
params
=
self
.
_get_params
(
[
inpt
for
(
inpt
,
needs_transform
)
in
zip
(
flat_inputs
,
needs_transform_list
)
if
needs_transform
]
)
flat_outputs
=
[
self
.
_transform
(
inpt
,
params
)
if
check_type
(
inpt
,
self
.
_transformed_types
)
else
inpt
for
inpt
in
flat_inputs
self
.
_transform
(
inpt
,
params
)
if
needs_transform
else
inpt
for
(
inpt
,
needs_transform
)
in
zip
(
flat_inputs
,
needs_transform_list
)
]
return
tree_unflatten
(
flat_outputs
,
spec
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment