Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b6574c92
Unverified
Commit
b6574c92
authored
Sep 18, 2023
by
Philip Meier
Committed by
GitHub
Sep 18, 2023
Browse files
port tests for transforms.ColorJitter (#7968)
parent
5fa8050d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
64 additions
and
60 deletions
+64
-60
test/test_transforms_v2_consistency.py
test/test_transforms_v2_consistency.py
+0
-60
test/test_transforms_v2_refactored.py
test/test_transforms_v2_refactored.py
+64
-0
No files found.
test/test_transforms_v2_consistency.py
View file @
b6574c92
...
...
@@ -228,23 +228,6 @@ CONSISTENCY_CONFIGS = [
# Use default tolerances of `torch.testing.assert_close`
closeness_kwargs
=
dict
(
rtol
=
None
,
atol
=
None
),
),
ConsistencyConfig
(
v2_transforms
.
ColorJitter
,
legacy_transforms
.
ColorJitter
,
[
ArgsKwargs
(),
ArgsKwargs
(
brightness
=
0.1
),
ArgsKwargs
(
brightness
=
(
0.2
,
0.3
)),
ArgsKwargs
(
contrast
=
0.4
),
ArgsKwargs
(
contrast
=
(
0.5
,
0.6
)),
ArgsKwargs
(
saturation
=
0.7
),
ArgsKwargs
(
saturation
=
(
0.8
,
0.9
)),
ArgsKwargs
(
hue
=
0.3
),
ArgsKwargs
(
hue
=
(
-
0.1
,
0.2
)),
ArgsKwargs
(
brightness
=
0.1
,
contrast
=
0.4
,
saturation
=
0.5
,
hue
=
0.3
),
],
closeness_kwargs
=
{
"atol"
:
1e-5
,
"rtol"
:
1e-5
},
),
ConsistencyConfig
(
v2_transforms
.
PILToTensor
,
legacy_transforms
.
PILToTensor
,
...
...
@@ -453,49 +436,6 @@ def test_call_consistency(config, args_kwargs):
)
get_params_parametrization
=
pytest
.
mark
.
parametrize
(
(
"config"
,
"get_params_args_kwargs"
),
[
pytest
.
param
(
next
(
config
for
config
in
CONSISTENCY_CONFIGS
if
config
.
prototype_cls
is
transform_cls
),
get_params_args_kwargs
,
id
=
transform_cls
.
__name__
,
)
for
transform_cls
,
get_params_args_kwargs
in
[
(
v2_transforms
.
ColorJitter
,
ArgsKwargs
(
brightness
=
None
,
contrast
=
None
,
saturation
=
None
,
hue
=
None
)),
(
v2_transforms
.
AutoAugment
,
ArgsKwargs
(
5
)),
]
],
)
@
get_params_parametrization
def
test_get_params_alias
(
config
,
get_params_args_kwargs
):
assert
config
.
prototype_cls
.
get_params
is
config
.
legacy_cls
.
get_params
if
not
config
.
args_kwargs
:
return
args
,
kwargs
=
config
.
args_kwargs
[
0
]
legacy_transform
=
config
.
legacy_cls
(
*
args
,
**
kwargs
)
prototype_transform
=
config
.
prototype_cls
(
*
args
,
**
kwargs
)
assert
prototype_transform
.
get_params
is
legacy_transform
.
get_params
@
get_params_parametrization
def
test_get_params_jit
(
config
,
get_params_args_kwargs
):
get_params_args
,
get_params_kwargs
=
get_params_args_kwargs
torch
.
jit
.
script
(
config
.
prototype_cls
.
get_params
)(
*
get_params_args
,
**
get_params_kwargs
)
if
not
config
.
args_kwargs
:
return
args
,
kwargs
=
config
.
args_kwargs
[
0
]
transform
=
config
.
prototype_cls
(
*
args
,
**
kwargs
)
torch
.
jit
.
script
(
transform
.
get_params
)(
*
get_params_args
,
**
get_params_kwargs
)
@
pytest
.
mark
.
parametrize
(
(
"config"
,
"args_kwargs"
),
[
...
...
test/test_transforms_v2_refactored.py
View file @
b6574c92
...
...
@@ -3881,3 +3881,67 @@ class TestPerspective:
)
assert_close
(
actual
,
expected
,
rtol
=
0
,
atol
=
1
)
class
TestColorJitter
:
@
pytest
.
mark
.
parametrize
(
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_video
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
uint8
,
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_transform
(
self
,
make_input
,
dtype
,
device
):
if
make_input
is
make_image_pil
and
not
(
dtype
is
torch
.
uint8
and
device
==
"cpu"
):
pytest
.
skip
(
"PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
"will degenerate to that anyway."
)
check_transform
(
transforms
.
ColorJitter
(
brightness
=
0.5
,
contrast
=
0.5
,
saturation
=
0.5
,
hue
=
0.25
),
make_input
(
dtype
=
dtype
,
device
=
device
),
)
def
test_transform_noop
(
self
):
input
=
make_image
()
input_version
=
input
.
_version
transform
=
transforms
.
ColorJitter
()
output
=
transform
(
input
)
assert
output
is
input
assert
output
.
data_ptr
()
==
input
.
data_ptr
()
assert
output
.
_version
==
input_version
def
test_transform_error
(
self
):
with
pytest
.
raises
(
ValueError
,
match
=
"must be non negative"
):
transforms
.
ColorJitter
(
brightness
=-
1
)
for
brightness
in
[
object
(),
[
1
,
2
,
3
]]:
with
pytest
.
raises
(
TypeError
,
match
=
"single number or a sequence with length 2"
):
transforms
.
ColorJitter
(
brightness
=
brightness
)
with
pytest
.
raises
(
ValueError
,
match
=
"values should be between"
):
transforms
.
ColorJitter
(
brightness
=
(
-
1
,
0.5
))
with
pytest
.
raises
(
ValueError
,
match
=
"values should be between"
):
transforms
.
ColorJitter
(
hue
=
1
)
@
pytest
.
mark
.
parametrize
(
"brightness"
,
[
None
,
0.1
,
(
0.2
,
0.3
)])
@
pytest
.
mark
.
parametrize
(
"contrast"
,
[
None
,
0.4
,
(
0.5
,
0.6
)])
@
pytest
.
mark
.
parametrize
(
"saturation"
,
[
None
,
0.7
,
(
0.8
,
0.9
)])
@
pytest
.
mark
.
parametrize
(
"hue"
,
[
None
,
0.3
,
(
-
0.1
,
0.2
)])
def
test_transform_correctness
(
self
,
brightness
,
contrast
,
saturation
,
hue
):
image
=
make_image
(
dtype
=
torch
.
uint8
,
device
=
"cpu"
)
transform
=
transforms
.
ColorJitter
(
brightness
=
brightness
,
contrast
=
contrast
,
saturation
=
saturation
,
hue
=
hue
)
with
freeze_rng_state
():
torch
.
manual_seed
(
0
)
actual
=
transform
(
image
)
torch
.
manual_seed
(
0
)
expected
=
F
.
to_image
(
transform
(
F
.
to_pil_image
(
image
)))
mae
=
(
actual
.
float
()
-
expected
.
float
()).
abs
().
mean
()
assert
mae
<
2
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment