OpenDAS / vision · Commits

Commit ddfee23d (unverified), parent 0040fe7a
Authored Oct 03, 2023 by Philip Meier; committed via GitHub on Oct 03, 2023

    port tests for container transforms (#8012)

Showing 4 changed files with 112 additions and 179 deletions:
  test/test_transforms_v2.py                  +0    −29
  test/test_transforms_v2_consistency.py      +1    −137
  test/test_transforms_v2_refactored.py       +102  −6
  torchvision/transforms/v2/_container.py     +9    −7
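The commit moves the container-transform tests (Compose, RandomApply, RandomChoice, RandomOrder) into test/test_transforms_v2_refactored.py and reworks the forward() methods of RandomApply and RandomOrder in _container.py so that multiple positional inputs are forwarded to the wrapped transforms unpacked. For orientation, a minimal sketch of how these torchvision.transforms.v2 containers compose is shown below; the particular wrapped transforms, probabilities, and input size are illustrative assumptions and are not part of the diff.

import torch
import torchvision.transforms.v2 as transforms

# Illustrative pipeline only; the concrete choices below do not appear in this commit.
pipeline = transforms.Compose(
    [
        # apply the wrapped transforms with probability 0.5
        transforms.RandomApply([transforms.RandomHorizontalFlip(p=1)], p=0.5),
        # pick exactly one of the wrapped transforms, weighted by p
        transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[0.3, 0.7]),
        # apply all wrapped transforms, in a random order
        transforms.RandomOrder([transforms.RandomVerticalFlip(p=1), transforms.Pad(2)]),
    ]
)

image = torch.rand(3, 32, 32)
output = pipeline(image)
assert isinstance(output, torch.Tensor)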
test/test_transforms_v2.py

@@ -122,35 +122,6 @@ class TestTransform:
         t(inpt)
 
 
-class TestContainers:
-    @pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
-    def test_assertions(self, transform_cls):
-        with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
-            transform_cls(transforms.RandomCrop(28))
-
-    @pytest.mark.parametrize("transform_cls", [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder])
-    @pytest.mark.parametrize(
-        "trfms",
-        [
-            [transforms.Pad(2), transforms.RandomCrop(28)],
-            [lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)],
-            [transforms.Pad(2), lambda x: 2.0 * x, transforms.RandomCrop(28)],
-        ],
-    )
-    def test_ctor(self, transform_cls, trfms):
-        c = transform_cls(trfms)
-        inpt = torch.rand(1, 3, 32, 32)
-        output = c(inpt)
-        assert isinstance(output, torch.Tensor)
-        assert output.ndim == 4
-
-
-class TestRandomChoice:
-    def test_assertions(self):
-        with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
-            transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1])
-
-
 class TestRandomIoUCrop:
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]])
test/test_transforms_v2_consistency.py
@@ -11,9 +11,7 @@ import pytest
 
 import torch
 import torchvision.transforms.v2 as v2_transforms
 from common_utils import assert_close, assert_equal, set_rng_seed
-from torch import nn
 from torchvision import transforms as legacy_transforms, tv_tensors
-from torchvision._utils import sequence_to_str
 from torchvision.transforms import functional as legacy_F
 from torchvision.transforms.v2 import functional as prototype_F
@@ -71,63 +69,7 @@ class ConsistencyConfig:
 LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
 LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)
 
-CONSISTENCY_CONFIGS = [
-    ConsistencyConfig(
-        v2_transforms.Compose,
-        legacy_transforms.Compose,
-    ),
-    ConsistencyConfig(
-        v2_transforms.RandomApply,
-        legacy_transforms.RandomApply,
-    ),
-    ConsistencyConfig(
-        v2_transforms.RandomChoice,
-        legacy_transforms.RandomChoice,
-    ),
-    ConsistencyConfig(
-        v2_transforms.RandomOrder,
-        legacy_transforms.RandomOrder,
-    ),
-]
-
-
-@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
-def test_signature_consistency(config):
-    legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
-    prototype_params = dict(inspect.signature(config.prototype_cls).parameters)
-
-    for param in config.removed_params:
-        legacy_params.pop(param, None)
-
-    missing = legacy_params.keys() - prototype_params.keys()
-    if missing:
-        raise AssertionError(
-            f"The prototype transform does not support the parameters "
-            f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
-            f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
-            f"the `ConsistencyConfig`."
-        )
-
-    extra = prototype_params.keys() - legacy_params.keys()
-    extra_without_default = {
-        param
-        for param in extra
-        if prototype_params[param].default is inspect.Parameter.empty
-        and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
-    }
-    if extra_without_default:
-        raise AssertionError(
-            f"The prototype transform requires the parameters "
-            f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
-            f"not. Please add a default value."
-        )
-
-    legacy_signature = list(legacy_params.keys())
-    # Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
-    # to the same number of parameters as the legacy one
-    prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]
-
-    assert prototype_signature == legacy_signature
+CONSISTENCY_CONFIGS = []
 
 
 def check_call_consistency(
@@ -288,84 +230,6 @@ def test_jit_consistency(config, args_kwargs):
     assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)
 
 
-class TestContainerTransforms:
-    """
-    Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
-    consistency automatically tests the wrapped transforms consistency.
-
-    Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
-    that were already tested for consistency above.
-    """
-
-    def test_compose(self):
-        prototype_transform = v2_transforms.Compose(
-            [
-                v2_transforms.Resize(256),
-                v2_transforms.CenterCrop(224),
-            ]
-        )
-        legacy_transform = legacy_transforms.Compose(
-            [
-                legacy_transforms.Resize(256),
-                legacy_transforms.CenterCrop(224),
-            ]
-        )
-
-        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
-
-    @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
-    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
-    def test_random_apply(self, p, sequence_type):
-        prototype_transform = v2_transforms.RandomApply(
-            sequence_type(
-                [
-                    v2_transforms.Resize(256),
-                    v2_transforms.CenterCrop(224),
-                ]
-            ),
-            p=p,
-        )
-        legacy_transform = legacy_transforms.RandomApply(
-            sequence_type(
-                [
-                    legacy_transforms.Resize(256),
-                    legacy_transforms.CenterCrop(224),
-                ]
-            ),
-            p=p,
-        )
-
-        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
-
-        if sequence_type is nn.ModuleList:
-            # quick and dirty test that it is jit-scriptable
-            scripted = torch.jit.script(prototype_transform)
-            scripted(torch.rand(1, 3, 300, 300))
-
-    # We can't test other values for `p` since the random parameter generation is different
-    @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
-    def test_random_choice(self, probabilities):
-        prototype_transform = v2_transforms.RandomChoice(
-            [
-                v2_transforms.Resize(256),
-                legacy_transforms.CenterCrop(224),
-            ],
-            p=probabilities,
-        )
-        legacy_transform = legacy_transforms.RandomChoice(
-            [
-                legacy_transforms.Resize(256),
-                legacy_transforms.CenterCrop(224),
-            ],
-            p=probabilities,
-        )
-
-        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
-        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
-
-
 class TestToTensorTransforms:
     def test_pil_to_tensor(self):
         prototype_transform = v2_transforms.PILToTensor()
test/test_transforms_v2_refactored.py
@@ -396,6 +396,8 @@ def check_transform(transform, input, check_v1_compatibility=True, check_sample_
     if check_v1_compatibility:
         _check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility))
 
+    return output
+
 
 def transform_cls_to_functional(transform_cls, **transform_specific_kwargs):
     def wrapper(input, *args, **kwargs):
@@ -1773,7 +1775,7 @@ class TestRotate:
             transforms.RandomAffine(degrees=0, fill="fill")
 
 
-class TestCompose:
+class TestContainerTransforms:
     class BuiltinTransform(transforms.Transform):
         def _transform(self, inpt, params):
             return inpt
@@ -1788,7 +1790,10 @@ class TestCompose:
             return image, label
 
     @pytest.mark.parametrize(
-        "transform_clss",
+        "transform_cls", [transforms.Compose, functools.partial(transforms.RandomApply, p=1), transforms.RandomOrder]
+    )
+    @pytest.mark.parametrize(
+        "wrapped_transform_clss",
         [
             [BuiltinTransform],
             [PackedInputTransform],
@@ -1803,12 +1808,12 @@ class TestCompose:
         ],
     )
     @pytest.mark.parametrize("unpack", [True, False])
-    def test_packed_unpacked(self, transform_clss, unpack):
-        needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)
-        needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss)
+    def test_packed_unpacked(self, transform_cls, wrapped_transform_clss, unpack):
+        needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in wrapped_transform_clss)
+        needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in wrapped_transform_clss)
         assert not (needs_packed_inputs and needs_unpacked_inputs)
 
-        transform = transforms.Compose([cls() for cls in transform_clss])
+        transform = transform_cls([cls() for cls in wrapped_transform_clss])
 
         image = make_image()
         label = 3
@@ -1833,6 +1838,97 @@ class TestCompose:
         assert output[0] is image
         assert output[1] is label
 
+    def test_compose(self):
+        transform = transforms.Compose(
+            [
+                transforms.RandomHorizontalFlip(p=1),
+                transforms.RandomVerticalFlip(p=1),
+            ]
+        )
+
+        input = make_image()
+
+        actual = check_transform(transform, input)
+        expected = F.vertical_flip(F.horizontal_flip(input))
+
+        assert_equal(actual, expected)
+
+    @pytest.mark.parametrize("p", [0.0, 1.0])
+    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
+    def test_random_apply(self, p, sequence_type):
+        transform = transforms.RandomApply(
+            sequence_type(
+                [
+                    transforms.RandomHorizontalFlip(p=1),
+                    transforms.RandomVerticalFlip(p=1),
+                ]
+            ),
+            p=p,
+        )
+
+        # This needs to be a pure tensor (or a PIL image), because otherwise check_transforms skips the v1 compatibility
+        # check
+        input = make_image_tensor()
+        output = check_transform(transform, input, check_v1_compatibility=issubclass(sequence_type, nn.ModuleList))
+
+        if p == 1:
+            assert_equal(output, F.vertical_flip(F.horizontal_flip(input)))
+        else:
+            assert output is input
+
+    @pytest.mark.parametrize("p", [(0, 1), (1, 0)])
+    def test_random_choice(self, p):
+        transform = transforms.RandomChoice(
+            [
+                transforms.RandomHorizontalFlip(p=1),
+                transforms.RandomVerticalFlip(p=1),
+            ],
+            p=p,
+        )
+
+        input = make_image()
+        output = check_transform(transform, input)
+
+        p_horz, p_vert = p
+        if p_horz:
+            assert_equal(output, F.horizontal_flip(input))
+        else:
+            assert_equal(output, F.vertical_flip(input))
+
+    def test_random_order(self):
+        transform = transforms.Compose(
+            [
+                transforms.RandomHorizontalFlip(p=1),
+                transforms.RandomVerticalFlip(p=1),
+            ]
+        )
+
+        input = make_image()
+
+        actual = check_transform(transform, input)
+        # We can't really check whether the transforms are actually applied in random order. However, horizontal and
+        # vertical flip are commutative. Meaning, even under the assumption that the transform applies them in random
+        # order, we can use a fixed order to compute the expected value.
+        expected = F.vertical_flip(F.horizontal_flip(input))
+
+        assert_equal(actual, expected)
+
+    def test_errors(self):
+        for cls in [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder]:
+            with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
+                cls(lambda x: x)
+
+        with pytest.raises(ValueError, match="at least one transform"):
+            transforms.Compose([])
+
+        for p in [-1, 2]:
+            with pytest.raises(ValueError, match=re.escape("value in the interval [0.0, 1.0]")):
+                transforms.RandomApply([lambda x: x], p=p)
+
+        for transforms_, p in [([lambda x: x], []), ([], [1.0])]:
+            with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"):
+                transforms.RandomChoice(transforms_, p=p)
+
 
 class TestToDtype:
     @pytest.mark.parametrize(
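A side note on the commutativity argument in the comment of test_random_order above: horizontal and vertical flips commute, so a fixed application order yields the same expected image regardless of the order a container actually picks. A minimal sketch of that property with the v2 functional API (illustrative only, not part of the commit):

import torch
from torchvision.transforms.v2 import functional as F

image = torch.rand(3, 32, 32)
# Flipping horizontally then vertically equals flipping vertically then horizontally.
assert torch.equal(
    F.vertical_flip(F.horizontal_flip(image)),
    F.horizontal_flip(F.vertical_flip(image)),
)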
torchvision/transforms/v2/_container.py
@@ -100,14 +100,15 @@ class RandomApply(Transform):
         return {"transforms": self.transforms, "p": self.p}
 
     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
+        needs_unpacking = len(inputs) > 1
 
         if torch.rand(1) >= self.p:
-            return sample
+            return inputs if needs_unpacking else inputs[0]
 
         for transform in self.transforms:
-            sample = transform(sample)
-        return sample
+            outputs = transform(*inputs)
+            inputs = outputs if needs_unpacking else (outputs,)
+        return outputs
 
     def extra_repr(self) -> str:
         format_string = []
@@ -173,8 +174,9 @@ class RandomOrder(Transform):
         self.transforms = transforms
 
     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
+        needs_unpacking = len(inputs) > 1
         for idx in torch.randperm(len(self.transforms)):
             transform = self.transforms[idx]
-            sample = transform(sample)
-        return sample
+            outputs = transform(*inputs)
+            inputs = outputs if needs_unpacking else (outputs,)
+        return outputs
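The net effect of the two forward() rewrites above: when a container is called with several positional inputs, each wrapped transform now receives them unpacked via transform(*inputs) instead of packed into a single tuple, while a single input is still returned unwrapped. A rough usage sketch under that reading; the Identity transform and the sample data below are hypothetical and not part of the commit.

import torch
import torchvision.transforms.v2 as transforms


class Identity(transforms.Transform):
    # Hypothetical no-op transform, used only to make the pass-through visible.
    def _transform(self, inpt, params):
        return inpt


image, label = torch.rand(3, 32, 32), 3

# With p=1.0 the wrapped transforms always run; both positional inputs are
# forwarded together and come back as a tuple of two outputs.
out_image, out_label = transforms.RandomApply([Identity()], p=1.0)(image, label)
assert out_image is image and out_label is label

# A single input is still returned as-is, not wrapped in a one-element tuple.
out = transforms.RandomOrder([Identity()])(image)
assert isinstance(out, torch.Tensor)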