OpenDAS / vision / Commits / 1120aa9e

Commit 1120aa9e (unverified), authored Feb 08, 2023 by Philip Meier, committed by GitHub on Feb 08, 2023:

    introduce heuristic for simple tensor handling of transforms v2 (#7170)
Parent: 1222b495

Showing 3 changed files with 250 additions and 101 deletions:

- test/test_prototype_transforms.py: +200 -97
- torchvision/prototype/transforms/_misc.py: +19 -0
- torchvision/prototype/transforms/_transform.py: +31 -4
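Before the per-file diffs, here is a minimal sketch of the user-visible behavior this commit introduces. It assumes the prototype `transforms` / `datapoints` API at this commit; `RandomHorizontalFlip` is only a convenient stand-in for any v2 transform:

```python
import torch
from torchvision.prototype import datapoints, transforms

transform = transforms.RandomHorizontalFlip(p=1.0)

# A lone simple tensor (a torch.Tensor that is not a datapoint) is still
# treated as an image, keeping backwards compatibility with transforms v1.
plain = torch.rand(3, 8, 8)
flipped = transform(plain)  # transformed

# Once an explicit image (or video) is part of the sample, simple tensors are
# assumed to be supplemental data and are passed through unchanged.
image = datapoints.Image(torch.rand(3, 8, 8))
out_image, out_plain = transform((image, plain))
assert out_plain is plain
```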
test/test_prototype_transforms.py (view file @ 1120aa9e)
```diff
 import itertools
+import re

 import numpy as np
 import PIL.Image
 import pytest
 import torch
 import torchvision.prototype.transforms.utils
-from common_utils import assert_equal, cpu_and_gpu
+from common_utils import cpu_and_gpu
 from prototype_common_utils import (
+    assert_equal,
     DEFAULT_EXTRA_DIMS,
     make_bounding_box,
     make_bounding_boxes,
     ...
@@ -25,7 +26,7 @@ from prototype_common_utils import (
 )
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import datapoints, transforms
-from torchvision.prototype.transforms.utils import check_type
+from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
 ...
```
```diff
@@ -222,6 +223,67 @@ class TestSmoke:
         transform(input)


+@pytest.mark.parametrize(
+    "flat_inputs",
+    itertools.permutations(
+        [
+            next(make_vanilla_tensor_images()),
+            next(make_vanilla_tensor_images()),
+            next(make_pil_images()),
+            make_image(),
+            next(make_videos()),
+        ],
+        3,
+    ),
+)
+def test_simple_tensor_heuristic(flat_inputs):
+    def split_on_simple_tensor(to_split):
+        # This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts:
+        # 1. The first simple tensor. If none is present, this will be `None`.
+        # 2. A list of the remaining simple tensors.
+        # 3. A list of all other items.
+        simple_tensors = []
+        others = []
+        # Splitting always happens on the original `flat_inputs`, so that any erroneous type changes made by the
+        # transform cannot affect the splitting.
+        for item, inpt in zip(to_split, flat_inputs):
+            (simple_tensors if is_simple_tensor(inpt) else others).append(item)
+        return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others
+
+    class CopyCloneTransform(transforms.Transform):
+        def _transform(self, inpt, params):
+            return inpt.clone() if isinstance(inpt, torch.Tensor) else inpt.copy()
+
+        @staticmethod
+        def was_applied(output, inpt):
+            identity = output is inpt
+            if identity:
+                return False
+
+            # Make sure nothing fishy is going on
+            assert_equal(output, inpt)
+            return True
+
+    first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs)
+
+    transform = CopyCloneTransform()
+    transformed_sample = transform(flat_inputs)
+
+    first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample)
+
+    if first_simple_tensor_input is not None:
+        if other_inputs:
+            assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
+        else:
+            assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
+
+    for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs):
+        assert not transform.was_applied(output, inpt)
+
+    for input, output in zip(other_inputs, other_outputs):
+        assert transform.was_applied(output, input)
+
+
 @pytest.mark.parametrize("p", [0.0, 1.0])
 class TestRandomHorizontalFlip:
     def input_expected_image_tensor(self, p, dtype=torch.float32):
     ...
```
```diff
@@ -1755,22 +1817,27 @@ class TestRandomResize:
     )


-@pytest.mark.parametrize(
-    ("dtype", "expected_dtypes"),
-    [
-        (
-            torch.float64,
-            {torch.Tensor: torch.float64, datapoints.Image: torch.float64, datapoints.BoundingBox: torch.float64},
-        ),
-        (
-            {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
-            {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
-        ),
-    ],
-)
-def test_to_dtype(dtype, expected_dtypes):
-    sample = dict(
-        plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"),
-        image=make_image(dtype=torch.uint8),
-        bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
-        str="str",
+class TestToDtype:
+    @pytest.mark.parametrize(
+        ("dtype", "expected_dtypes"),
+        [
+            (
+                torch.float64,
+                {
+                    datapoints.Video: torch.float64,
+                    datapoints.Image: torch.float64,
+                    datapoints.BoundingBox: torch.float64,
+                },
+            ),
+            (
+                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
+                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
+            ),
+        ],
+    )
+    def test_call(self, dtype, expected_dtypes):
+        sample = dict(
+            video=make_video(dtype=torch.int64),
+            image=make_image(dtype=torch.uint8),
+            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
+            str="str",
     ...
```
```diff
@@ -1792,23 +1859,35 @@ def test_to_dtype(dtype, expected_dtypes):
         else:
             assert transformed_value is value

+    @pytest.mark.filterwarnings("error")
+    def test_plain_tensor_call(self):
+        tensor = torch.empty((), dtype=torch.float32)
+        transform = transforms.ToDtype({torch.Tensor: torch.float64})
+
+        assert transform(tensor).dtype is torch.float64
+
+    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
+    def test_plain_tensor_warning(self, other_type):
+        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
+            transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64})

-@pytest.mark.parametrize(
-    ("dims", "inverse_dims"),
-    [
-        (
-            {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: None},
-            {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: None},
-        ),
-        (
-            {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)},
-            {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)},
-        ),
-    ],
-)
-def test_permute_dimensions(dims, inverse_dims):
-    sample = dict(
-        plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
-        image=make_image(),
-        bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
-        video=make_video(),
+
+class TestPermuteDimensions:
+    @pytest.mark.parametrize(
+        ("dims", "inverse_dims"),
+        [
+            (
+                {datapoints.Image: (2, 1, 0), datapoints.Video: None},
+                {datapoints.Image: (2, 1, 0), datapoints.Video: None},
+            ),
+            (
+                {datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)},
+                {datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)},
+            ),
+        ],
+    )
+    def test_call(self, dims, inverse_dims):
+        sample = dict(
+            image=make_image(),
+            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
+            video=make_video(),
     ...
```
```diff
@@ -1832,17 +1911,29 @@ def test_permute_dimensions(dims, inverse_dims):
         else:
             assert transformed_value is value

+    @pytest.mark.filterwarnings("error")
+    def test_plain_tensor_call(self):
+        tensor = torch.empty((2, 3, 4))
+        transform = transforms.PermuteDimensions(dims=(1, 2, 0))
+
+        assert transform(tensor).shape == (3, 4, 2)
+
+    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
+    def test_plain_tensor_warning(self, other_type):
+        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
+            transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})

-@pytest.mark.parametrize(
-    "dims",
-    [
-        (-1, -2),
-        {torch.Tensor: (-1, -2), datapoints.Image: (1, 2), datapoints.Video: None},
-    ],
-)
-def test_transpose_dimensions(dims):
-    sample = dict(
-        plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
-        image=make_image(),
-        bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
-        video=make_video(),
+
+class TestTransposeDimensions:
+    @pytest.mark.parametrize(
+        "dims",
+        [
+            (-1, -2),
+            {datapoints.Image: (1, 2), datapoints.Video: None},
+        ],
+    )
+    def test_call(self, dims):
+        sample = dict(
+            image=make_image(),
+            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
+            video=make_video(),
     ...
```
```diff
@@ -1867,6 +1958,18 @@ def test_transpose_dimensions(dims):
         else:
             assert transformed_value is value

+    @pytest.mark.filterwarnings("error")
+    def test_plain_tensor_call(self):
+        tensor = torch.empty((2, 3, 4))
+        transform = transforms.TransposeDimensions(dims=(0, 2))
+
+        assert transform(tensor).shape == (4, 3, 2)
+
+    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
+    def test_plain_tensor_warning(self, other_type):
+        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
+            transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
+

 class TestUniformTemporalSubsample:
     @pytest.mark.parametrize(
 ...
```
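The tests above lean on `is_simple_tensor` from `torchvision.prototype.transforms.utils`, whose definition is not part of this diff. As a rough sketch of its semantics (the function name and the list of datapoint classes below are illustrative, not exhaustive):

```python
import torch
from torchvision.prototype import datapoints

# Hypothetical stand-in for the real helper, for readers of this diff only.
def is_simple_tensor_sketch(inpt):
    # A "simple" tensor is a raw torch.Tensor that is *not* one of the datapoint
    # subclasses. PIL images are not tensors at all, so they never qualify.
    return isinstance(inpt, torch.Tensor) and not isinstance(
        inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox)
    )
```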
torchvision/prototype/transforms/_misc.py (view file @ 1120aa9e)
```diff
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

 import PIL.Image
 ...
@@ -155,6 +156,12 @@ class ToDtype(Transform):
         super().__init__()
         if not isinstance(dtype, dict):
             dtype = _get_defaultdict(dtype)
+        if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]):
+            warnings.warn(
+                "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
+                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
+                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
+            )
         self.dtype = dtype

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 ...
@@ -171,6 +178,12 @@ class PermuteDimensions(Transform):
         super().__init__()
         if not isinstance(dims, dict):
             dims = _get_defaultdict(dims)
+        if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]):
+            warnings.warn(
+                "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
+                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
+                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
+            )
         self.dims = dims

     def _transform(
 ...
@@ -189,6 +202,12 @@ class TransposeDimensions(Transform):
         super().__init__()
         if not isinstance(dims, dict):
             dims = _get_defaultdict(dims)
+        if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]):
+            warnings.warn(
+                "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
+                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
+                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
+            )
         self.dims = dims

     def _transform(
 ...
```
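As a usage note on the guard added to all three constructors (a sketch, not part of the diff): the warning fires only when a `torch.Tensor` entry is combined with an explicit `datapoints.Image` or `datapoints.Video` entry, because the plain-tensor entry can never apply while an image or video is present in the sample.

```python
import torch
from torchvision.prototype import datapoints, transforms

# Fine: no plain-tensor key.
transforms.ToDtype({datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64})

# Fine: a plain-tensor key alone; the heuristic treats the first simple tensor
# as an image when no explicit image or video is in the sample.
transforms.ToDtype({torch.Tensor: torch.float64})

# Warns: with an Image (or Video) key present, a plain torch.Tensor in the
# sample is passed through untouched, so its entry cannot take effect there.
transforms.ToDtype({torch.Tensor: torch.float32, datapoints.Image: torch.float64})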
torchvision/prototype/transforms/_transform.py (view file @ 1120aa9e)
```diff
 ...
@@ -7,7 +7,8 @@ import PIL.Image
 import torch
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision.prototype.transforms.utils import check_type
+from torchvision.prototype import datapoints
+from torchvision.prototype.transforms.utils import check_type, has_any, is_simple_tensor
 from torchvision.utils import _log_api_usage_once
 ...
@@ -37,9 +38,35 @@ class Transform(nn.Module):
         params = self._get_params(flat_inputs)

-        flat_outputs = [
-            self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs
-        ]
+        # Below is a heuristic on how to deal with simple tensor inputs:
+        # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
+        #    (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
+        # 2. If there is no explicit image or video in the sample, only the first encountered simple tensor is
+        #    transformed as an image, while the rest are passed through. The order is defined by the returned
+        #    `flat_inputs` of `tree_flatten`, which recurses depth-first through the input.
+        #
+        # This heuristic stems from two requirements:
+        # 1. We need to keep BC for single-input simple tensors and treat them as images.
+        # 2. We don't want to treat all simple tensors as images, because some datasets like `CelebA` or `Widerface`
+        #    return supplemental numerical data as tensors that cannot be transformed as images.
+        #
+        # The heuristic should work well for most people in practice. The only case where it doesn't is if someone
+        # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
+        # However, this case wasn't supported by transforms v1 either, so there is no BC concern.
+        flat_outputs = []
+        transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
+        for inpt in flat_inputs:
+            needs_transform = True
+
+            if not check_type(inpt, self._transformed_types):
+                needs_transform = False
+            elif is_simple_tensor(inpt):
+                if transform_simple_tensor:
+                    transform_simple_tensor = False
+                else:
+                    needs_transform = False
+
+            flat_outputs.append(self._transform(inpt, params) if needs_transform else inpt)

         return tree_unflatten(flat_outputs, spec)
 ...
```
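To make rule 2 of the heuristic concrete, here is a small sketch of the new `forward` dispatch (assuming the prototype API at this commit; `RandomHorizontalFlip` is again just a convenient stand-in):

```python
import torch
from torchvision.prototype import datapoints, transforms

transform = transforms.RandomHorizontalFlip(p=1.0)

# No explicit image or video in the sample: only the first simple tensor
# (in the depth-first order of tree_flatten) is transformed as an image.
first, second = torch.rand(3, 8, 8), torch.rand(3, 8, 8)
out_first, out_second = transform((first, second))
assert out_second is second  # the second tensor is passed through untouched

# With an explicit video present, no simple tensor is transformed at all.
video = datapoints.Video(torch.rand(2, 3, 8, 8))
out_video, out_first = transform((video, first))
assert out_first is first
```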