Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
1f94320d
Unverified
Commit
1f94320d
authored
Sep 04, 2023
by
Philip Meier
Committed by
GitHub
Sep 04, 2023
Browse files
port AA tests (#7927)
Co-authored-by:
Nicolas Hug
<
contact@nicolas-hug.com
>
parent
d0e16b76
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
114 additions
and
278 deletions
+114
-278
test/test_transforms_v2_consistency.py
test/test_transforms_v2_consistency.py
+0
-275
test/test_transforms_v2_refactored.py
test/test_transforms_v2_refactored.py
+114
-3
No files found.
test/test_transforms_v2_consistency.py
View file @
1f94320d
...
...
@@ -705,281 +705,6 @@ class TestToTensorTransforms:
assert_equal
(
prototype_transform
(
image_numpy
),
legacy_transform
(
image_numpy
))
class
TestAATransforms
:
@
pytest
.
mark
.
parametrize
(
"inpt"
,
[
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
),
PIL
.
Image
.
new
(
"RGB"
,
(
256
,
256
),
123
),
tv_tensors
.
Image
(
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)),
],
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
v2_transforms
.
InterpolationMode
.
NEAREST
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
PIL
.
Image
.
NEAREST
,
],
)
def
test_randaug
(
self
,
inpt
,
interpolation
,
mocker
):
t_ref
=
legacy_transforms
.
RandAugment
(
interpolation
=
interpolation
,
num_ops
=
1
)
t
=
v2_transforms
.
RandAugment
(
interpolation
=
interpolation
,
num_ops
=
1
)
le
=
len
(
t
.
_AUGMENTATION_SPACE
)
keys
=
list
(
t
.
_AUGMENTATION_SPACE
.
keys
())
randint_values
=
[]
for
i
in
range
(
le
):
# Stable API, op_index random call
randint_values
.
append
(
i
)
# Stable API, if signed there is another random call
if
t
.
_AUGMENTATION_SPACE
[
keys
[
i
]][
1
]:
randint_values
.
append
(
0
)
# New API, _get_random_item
randint_values
.
append
(
i
)
randint_values
=
iter
(
randint_values
)
mocker
.
patch
(
"torch.randint"
,
side_effect
=
lambda
*
arg
,
**
kwargs
:
torch
.
tensor
(
next
(
randint_values
)))
mocker
.
patch
(
"torch.rand"
,
return_value
=
1.0
)
for
i
in
range
(
le
):
expected_output
=
t_ref
(
inpt
)
output
=
t
(
inpt
)
assert_close
(
expected_output
,
output
,
atol
=
1
,
rtol
=
0.1
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
v2_transforms
.
InterpolationMode
.
NEAREST
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
],
)
@
pytest
.
mark
.
parametrize
(
"fill"
,
[
None
,
85
,
(
10
,
-
10
,
10
),
0.7
,
[
0.0
,
0.0
,
0.0
],
[
1
],
1
])
def
test_randaug_jit
(
self
,
interpolation
,
fill
):
inpt
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)
t_ref
=
legacy_transforms
.
RandAugment
(
interpolation
=
interpolation
,
num_ops
=
1
,
fill
=
fill
)
t
=
v2_transforms
.
RandAugment
(
interpolation
=
interpolation
,
num_ops
=
1
,
fill
=
fill
)
tt_ref
=
torch
.
jit
.
script
(
t_ref
)
tt
=
torch
.
jit
.
script
(
t
)
torch
.
manual_seed
(
12
)
expected_output
=
tt_ref
(
inpt
)
torch
.
manual_seed
(
12
)
scripted_output
=
tt
(
inpt
)
assert_equal
(
scripted_output
,
expected_output
)
@
pytest
.
mark
.
parametrize
(
"inpt"
,
[
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
),
PIL
.
Image
.
new
(
"RGB"
,
(
256
,
256
),
123
),
tv_tensors
.
Image
(
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)),
],
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
v2_transforms
.
InterpolationMode
.
NEAREST
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
PIL
.
Image
.
NEAREST
,
],
)
def
test_trivial_aug
(
self
,
inpt
,
interpolation
,
mocker
):
t_ref
=
legacy_transforms
.
TrivialAugmentWide
(
interpolation
=
interpolation
)
t
=
v2_transforms
.
TrivialAugmentWide
(
interpolation
=
interpolation
)
le
=
len
(
t
.
_AUGMENTATION_SPACE
)
keys
=
list
(
t
.
_AUGMENTATION_SPACE
.
keys
())
randint_values
=
[]
for
i
in
range
(
le
):
# Stable API, op_index random call
randint_values
.
append
(
i
)
key
=
keys
[
i
]
# Stable API, random magnitude
aug_op
=
t
.
_AUGMENTATION_SPACE
[
key
]
magnitudes
=
aug_op
[
0
](
2
,
0
,
0
)
if
magnitudes
is
not
None
:
randint_values
.
append
(
5
)
# Stable API, if signed there is another random call
if
aug_op
[
1
]:
randint_values
.
append
(
0
)
# New API, _get_random_item
randint_values
.
append
(
i
)
# New API, random magnitude
if
magnitudes
is
not
None
:
randint_values
.
append
(
5
)
randint_values
=
iter
(
randint_values
)
mocker
.
patch
(
"torch.randint"
,
side_effect
=
lambda
*
arg
,
**
kwargs
:
torch
.
tensor
(
next
(
randint_values
)))
mocker
.
patch
(
"torch.rand"
,
return_value
=
1.0
)
for
_
in
range
(
le
):
expected_output
=
t_ref
(
inpt
)
output
=
t
(
inpt
)
assert_close
(
expected_output
,
output
,
atol
=
1
,
rtol
=
0.1
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
v2_transforms
.
InterpolationMode
.
NEAREST
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
],
)
@
pytest
.
mark
.
parametrize
(
"fill"
,
[
None
,
85
,
(
10
,
-
10
,
10
),
0.7
,
[
0.0
,
0.0
,
0.0
],
[
1
],
1
])
def
test_trivial_aug_jit
(
self
,
interpolation
,
fill
):
inpt
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)
t_ref
=
legacy_transforms
.
TrivialAugmentWide
(
interpolation
=
interpolation
,
fill
=
fill
)
t
=
v2_transforms
.
TrivialAugmentWide
(
interpolation
=
interpolation
,
fill
=
fill
)
tt_ref
=
torch
.
jit
.
script
(
t_ref
)
tt
=
torch
.
jit
.
script
(
t
)
torch
.
manual_seed
(
12
)
expected_output
=
tt_ref
(
inpt
)
torch
.
manual_seed
(
12
)
scripted_output
=
tt
(
inpt
)
assert_equal
(
scripted_output
,
expected_output
)
@
pytest
.
mark
.
parametrize
(
"inpt"
,
[
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
),
PIL
.
Image
.
new
(
"RGB"
,
(
256
,
256
),
123
),
tv_tensors
.
Image
(
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)),
],
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
v2_transforms
.
InterpolationMode
.
NEAREST
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
PIL
.
Image
.
NEAREST
,
],
)
def
test_augmix
(
self
,
inpt
,
interpolation
,
mocker
):
t_ref
=
legacy_transforms
.
AugMix
(
interpolation
=
interpolation
,
mixture_width
=
1
,
chain_depth
=
1
)
t_ref
.
_sample_dirichlet
=
lambda
t
:
t
.
softmax
(
dim
=-
1
)
t
=
v2_transforms
.
AugMix
(
interpolation
=
interpolation
,
mixture_width
=
1
,
chain_depth
=
1
)
t
.
_sample_dirichlet
=
lambda
t
:
t
.
softmax
(
dim
=-
1
)
le
=
len
(
t
.
_AUGMENTATION_SPACE
)
keys
=
list
(
t
.
_AUGMENTATION_SPACE
.
keys
())
randint_values
=
[]
for
i
in
range
(
le
):
# Stable API, op_index random call
randint_values
.
append
(
i
)
key
=
keys
[
i
]
# Stable API, random magnitude
aug_op
=
t
.
_AUGMENTATION_SPACE
[
key
]
magnitudes
=
aug_op
[
0
](
2
,
0
,
0
)
if
magnitudes
is
not
None
:
randint_values
.
append
(
5
)
# Stable API, if signed there is another random call
if
aug_op
[
1
]:
randint_values
.
append
(
0
)
# New API, _get_random_item
randint_values
.
append
(
i
)
# New API, random magnitude
if
magnitudes
is
not
None
:
randint_values
.
append
(
5
)
randint_values
=
iter
(
randint_values
)
mocker
.
patch
(
"torch.randint"
,
side_effect
=
lambda
*
arg
,
**
kwargs
:
torch
.
tensor
(
next
(
randint_values
)))
mocker
.
patch
(
"torch.rand"
,
return_value
=
1.0
)
expected_output
=
t_ref
(
inpt
)
output
=
t
(
inpt
)
assert_equal
(
expected_output
,
output
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
v2_transforms
.
InterpolationMode
.
NEAREST
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
],
)
@
pytest
.
mark
.
parametrize
(
"fill"
,
[
None
,
85
,
(
10
,
-
10
,
10
),
0.7
,
[
0.0
,
0.0
,
0.0
],
[
1
],
1
])
def
test_augmix_jit
(
self
,
interpolation
,
fill
):
inpt
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)
t_ref
=
legacy_transforms
.
AugMix
(
interpolation
=
interpolation
,
mixture_width
=
1
,
chain_depth
=
1
,
fill
=
fill
)
t
=
v2_transforms
.
AugMix
(
interpolation
=
interpolation
,
mixture_width
=
1
,
chain_depth
=
1
,
fill
=
fill
)
tt_ref
=
torch
.
jit
.
script
(
t_ref
)
tt
=
torch
.
jit
.
script
(
t
)
torch
.
manual_seed
(
12
)
expected_output
=
tt_ref
(
inpt
)
torch
.
manual_seed
(
12
)
scripted_output
=
tt
(
inpt
)
assert_equal
(
scripted_output
,
expected_output
)
@
pytest
.
mark
.
parametrize
(
"inpt"
,
[
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
),
PIL
.
Image
.
new
(
"RGB"
,
(
256
,
256
),
123
),
tv_tensors
.
Image
(
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)),
],
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
v2_transforms
.
InterpolationMode
.
NEAREST
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
PIL
.
Image
.
NEAREST
,
],
)
def
test_aa
(
self
,
inpt
,
interpolation
):
aa_policy
=
legacy_transforms
.
AutoAugmentPolicy
(
"imagenet"
)
t_ref
=
legacy_transforms
.
AutoAugment
(
aa_policy
,
interpolation
=
interpolation
)
t
=
v2_transforms
.
AutoAugment
(
aa_policy
,
interpolation
=
interpolation
)
torch
.
manual_seed
(
12
)
expected_output
=
t_ref
(
inpt
)
torch
.
manual_seed
(
12
)
output
=
t
(
inpt
)
assert_equal
(
expected_output
,
output
)
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
v2_transforms
.
InterpolationMode
.
NEAREST
,
v2_transforms
.
InterpolationMode
.
BILINEAR
,
],
)
def
test_aa_jit
(
self
,
interpolation
):
inpt
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
256
,
256
),
dtype
=
torch
.
uint8
)
aa_policy
=
legacy_transforms
.
AutoAugmentPolicy
(
"imagenet"
)
t_ref
=
legacy_transforms
.
AutoAugment
(
aa_policy
,
interpolation
=
interpolation
)
t
=
v2_transforms
.
AutoAugment
(
aa_policy
,
interpolation
=
interpolation
)
tt_ref
=
torch
.
jit
.
script
(
t_ref
)
tt
=
torch
.
jit
.
script
(
t
)
torch
.
manual_seed
(
12
)
expected_output
=
tt_ref
(
inpt
)
torch
.
manual_seed
(
12
)
scripted_output
=
tt
(
inpt
)
assert_equal
(
scripted_output
,
expected_output
)
def
import_transforms_from_references
(
reference
):
HERE
=
Path
(
__file__
).
parent
PROJECT_ROOT
=
HERE
.
parent
...
...
test/test_transforms_v2_refactored.py
View file @
1f94320d
...
...
@@ -232,7 +232,7 @@ def _check_transform_v1_compatibility(transform, input, *, rtol, atol):
"""If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static
``get_params`` method that is the v1 equivalent, the output is close to v1, is scriptable, and the scripted version
can be called without error."""
if
type
(
input
)
is
not
torch
.
Tensor
or
isinstance
(
input
,
PIL
.
Image
.
Image
):
if
not
(
type
(
input
)
is
torch
.
Tensor
or
isinstance
(
input
,
PIL
.
Image
.
Image
)
)
:
return
v1_transform_cls
=
transform
.
_v1_transform_cls
...
...
@@ -250,7 +250,7 @@ def _check_transform_v1_compatibility(transform, input, *, rtol, atol):
with
freeze_rng_state
():
output_v1
=
v1_transform
(
input
)
assert_close
(
output_v2
,
output_v1
,
rtol
=
rtol
,
atol
=
atol
)
assert_close
(
F
.
to_image
(
output_v2
)
,
F
.
to_image
(
output_v1
)
,
rtol
=
rtol
,
atol
=
atol
)
if
isinstance
(
input
,
PIL
.
Image
.
Image
):
return
...
...
@@ -2772,7 +2772,10 @@ class TestErase:
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_transform
(
self
,
make_input
,
device
):
check_transform
(
transforms
.
RandomErasing
(
p
=
1
),
make_input
(
device
=
device
))
input
=
make_input
(
device
=
device
)
check_transform
(
transforms
.
RandomErasing
(
p
=
1
),
input
,
check_v1_compatibility
=
not
isinstance
(
input
,
PIL
.
Image
.
Image
)
)
def
_reference_erase_image
(
self
,
image
,
*
,
i
,
j
,
h
,
w
,
v
):
mask
=
torch
.
zeros_like
(
image
,
dtype
=
torch
.
bool
)
...
...
@@ -2898,3 +2901,111 @@ class TestGaussianBlur:
else
:
assert
sigma
[
0
]
<=
params
[
"sigma"
][
0
]
<=
sigma
[
1
]
assert
sigma
[
0
]
<=
params
[
"sigma"
][
1
]
<=
sigma
[
1
]
class
TestAutoAugmentTransforms
:
# These transforms have a lot of branches in their `forward()` passes which are conditioned on random sampling.
# It's typically very hard to test the effect on some parameters without heavy mocking logic.
# This class adds correctness tests for the kernels that are specific to those transforms. The rest of kernels, e.g.
# rotate, are tested in their respective classes. The rest of the tests here are mostly smoke tests.
def
_reference_shear_translate
(
self
,
image
,
*
,
transform_id
,
magnitude
,
interpolation
,
fill
):
if
isinstance
(
image
,
PIL
.
Image
.
Image
):
input
=
image
else
:
input
=
F
.
to_pil_image
(
image
)
matrix
=
{
"ShearX"
:
(
1
,
magnitude
,
0
,
0
,
1
,
0
),
"ShearY"
:
(
1
,
0
,
0
,
magnitude
,
1
,
0
),
"TranslateX"
:
(
1
,
0
,
-
int
(
magnitude
),
0
,
1
,
0
),
"TranslateY"
:
(
1
,
0
,
0
,
0
,
1
,
-
int
(
magnitude
)),
}[
transform_id
]
output
=
input
.
transform
(
input
.
size
,
PIL
.
Image
.
AFFINE
,
matrix
,
resample
=
pil_modes_mapping
[
interpolation
],
fill
=
fill
)
if
isinstance
(
image
,
PIL
.
Image
.
Image
):
return
output
else
:
return
F
.
to_image
(
output
)
@
pytest
.
mark
.
parametrize
(
"transform_id"
,
[
"ShearX"
,
"ShearY"
,
"TranslateX"
,
"TranslateY"
])
@
pytest
.
mark
.
parametrize
(
"magnitude"
,
[
0.3
,
-
0.2
,
0.0
])
@
pytest
.
mark
.
parametrize
(
"interpolation"
,
[
transforms
.
InterpolationMode
.
NEAREST
,
transforms
.
InterpolationMode
.
BILINEAR
]
)
@
pytest
.
mark
.
parametrize
(
"fill"
,
CORRECTNESS_FILLS
)
@
pytest
.
mark
.
parametrize
(
"input_type"
,
[
"Tensor"
,
"PIL"
])
def
test_correctness_shear_translate
(
self
,
transform_id
,
magnitude
,
interpolation
,
fill
,
input_type
):
# ShearX/Y and TranslateX/Y are the only ops that are native to the AA transforms. They are modeled after the
# reference implementation:
# https://github.com/tensorflow/models/blob/885fda091c46c59d6c7bb5c7e760935eacc229da/research/autoaugment/augmentation_transforms.py#L273-L362
# All other ops are checked in their respective dedicated tests.
image
=
make_image
(
dtype
=
torch
.
uint8
,
device
=
"cpu"
)
if
input_type
==
"PIL"
:
image
=
F
.
to_pil_image
(
image
)
if
"Translate"
in
transform_id
:
# For TranslateX/Y magnitude is a value in pixels
magnitude
*=
min
(
F
.
get_size
(
image
))
actual
=
transforms
.
AutoAugment
().
_apply_image_or_video_transform
(
image
,
transform_id
=
transform_id
,
magnitude
=
magnitude
,
interpolation
=
interpolation
,
fill
=
{
type
(
image
):
fill
},
)
expected
=
self
.
_reference_shear_translate
(
image
,
transform_id
=
transform_id
,
magnitude
=
magnitude
,
interpolation
=
interpolation
,
fill
=
fill
)
if
input_type
==
"PIL"
:
actual
,
expected
=
F
.
to_image
(
actual
),
F
.
to_image
(
expected
)
if
"Shear"
in
transform_id
and
input_type
==
"Tensor"
:
mae
=
(
actual
.
float
()
-
expected
.
float
()).
abs
().
mean
()
assert
mae
<
(
12
if
interpolation
is
transforms
.
InterpolationMode
.
NEAREST
else
5
)
else
:
assert_close
(
actual
,
expected
,
rtol
=
0
,
atol
=
1
)
@
pytest
.
mark
.
parametrize
(
"transform"
,
[
transforms
.
AutoAugment
(),
transforms
.
RandAugment
(),
transforms
.
TrivialAugmentWide
(),
transforms
.
AugMix
()],
)
@
pytest
.
mark
.
parametrize
(
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_video
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
uint8
,
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_transform_smoke
(
self
,
transform
,
make_input
,
dtype
,
device
):
if
make_input
is
make_image_pil
and
not
(
dtype
is
torch
.
uint8
and
device
==
"cpu"
):
pytest
.
skip
(
"PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
"will degenerate to that anyway."
)
input
=
make_input
(
dtype
=
dtype
,
device
=
device
)
with
freeze_rng_state
():
# By default every test starts from the same random seed. This leads to minimal coverage of the sampling
# that happens inside forward(). To avoid calling the transform multiple times to achieve higher coverage,
# we build a reproducible random seed from the input type, dtype, and device.
torch
.
manual_seed
(
hash
((
make_input
,
dtype
,
device
)))
# For v2, we changed the random sampling of the AA transforms. This makes it impossible to compare the v1
# and v2 outputs without complicated mocking and monkeypatching. Thus, we skip the v1 compatibility checks
# here and only check if we can script the v2 transform and subsequently call the result.
check_transform
(
transform
,
input
,
check_v1_compatibility
=
False
)
if
type
(
input
)
is
torch
.
Tensor
and
dtype
is
torch
.
uint8
:
_script
(
transform
)(
input
)
def
test_auto_augment_policy_error
(
self
):
with
pytest
.
raises
(
ValueError
,
match
=
"provided policy"
):
transforms
.
AutoAugment
(
policy
=
None
)
@
pytest
.
mark
.
parametrize
(
"severity"
,
[
0
,
11
])
def
test_aug_mix_severity_error
(
self
,
severity
):
with
pytest
.
raises
(
ValueError
,
match
=
"severity must be between"
):
transforms
.
AugMix
(
severity
=
severity
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment