Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b6574c92
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "58237364b1780223f48a80256f56408efe7b59a0"
Unverified
Commit
b6574c92
authored
Sep 18, 2023
by
Philip Meier
Committed by
GitHub
Sep 18, 2023
Browse files
port tests for transforms.ColorJitter (#7968)
parent
5fa8050d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
64 additions
and
60 deletions
+64
-60
test/test_transforms_v2_consistency.py
test/test_transforms_v2_consistency.py
+0
-60
test/test_transforms_v2_refactored.py
test/test_transforms_v2_refactored.py
+64
-0
No files found.
test/test_transforms_v2_consistency.py
View file @
b6574c92
...
@@ -228,23 +228,6 @@ CONSISTENCY_CONFIGS = [
...
@@ -228,23 +228,6 @@ CONSISTENCY_CONFIGS = [
# Use default tolerances of `torch.testing.assert_close`
# Use default tolerances of `torch.testing.assert_close`
closeness_kwargs
=
dict
(
rtol
=
None
,
atol
=
None
),
closeness_kwargs
=
dict
(
rtol
=
None
,
atol
=
None
),
),
),
ConsistencyConfig
(
v2_transforms
.
ColorJitter
,
legacy_transforms
.
ColorJitter
,
[
ArgsKwargs
(),
ArgsKwargs
(
brightness
=
0.1
),
ArgsKwargs
(
brightness
=
(
0.2
,
0.3
)),
ArgsKwargs
(
contrast
=
0.4
),
ArgsKwargs
(
contrast
=
(
0.5
,
0.6
)),
ArgsKwargs
(
saturation
=
0.7
),
ArgsKwargs
(
saturation
=
(
0.8
,
0.9
)),
ArgsKwargs
(
hue
=
0.3
),
ArgsKwargs
(
hue
=
(
-
0.1
,
0.2
)),
ArgsKwargs
(
brightness
=
0.1
,
contrast
=
0.4
,
saturation
=
0.5
,
hue
=
0.3
),
],
closeness_kwargs
=
{
"atol"
:
1e-5
,
"rtol"
:
1e-5
},
),
ConsistencyConfig
(
ConsistencyConfig
(
v2_transforms
.
PILToTensor
,
v2_transforms
.
PILToTensor
,
legacy_transforms
.
PILToTensor
,
legacy_transforms
.
PILToTensor
,
...
@@ -453,49 +436,6 @@ def test_call_consistency(config, args_kwargs):
...
@@ -453,49 +436,6 @@ def test_call_consistency(config, args_kwargs):
)
)
get_params_parametrization
=
pytest
.
mark
.
parametrize
(
(
"config"
,
"get_params_args_kwargs"
),
[
pytest
.
param
(
next
(
config
for
config
in
CONSISTENCY_CONFIGS
if
config
.
prototype_cls
is
transform_cls
),
get_params_args_kwargs
,
id
=
transform_cls
.
__name__
,
)
for
transform_cls
,
get_params_args_kwargs
in
[
(
v2_transforms
.
ColorJitter
,
ArgsKwargs
(
brightness
=
None
,
contrast
=
None
,
saturation
=
None
,
hue
=
None
)),
(
v2_transforms
.
AutoAugment
,
ArgsKwargs
(
5
)),
]
],
)
@
get_params_parametrization
def
test_get_params_alias
(
config
,
get_params_args_kwargs
):
assert
config
.
prototype_cls
.
get_params
is
config
.
legacy_cls
.
get_params
if
not
config
.
args_kwargs
:
return
args
,
kwargs
=
config
.
args_kwargs
[
0
]
legacy_transform
=
config
.
legacy_cls
(
*
args
,
**
kwargs
)
prototype_transform
=
config
.
prototype_cls
(
*
args
,
**
kwargs
)
assert
prototype_transform
.
get_params
is
legacy_transform
.
get_params
@
get_params_parametrization
def
test_get_params_jit
(
config
,
get_params_args_kwargs
):
get_params_args
,
get_params_kwargs
=
get_params_args_kwargs
torch
.
jit
.
script
(
config
.
prototype_cls
.
get_params
)(
*
get_params_args
,
**
get_params_kwargs
)
if
not
config
.
args_kwargs
:
return
args
,
kwargs
=
config
.
args_kwargs
[
0
]
transform
=
config
.
prototype_cls
(
*
args
,
**
kwargs
)
torch
.
jit
.
script
(
transform
.
get_params
)(
*
get_params_args
,
**
get_params_kwargs
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"config"
,
"args_kwargs"
),
(
"config"
,
"args_kwargs"
),
[
[
...
...
test/test_transforms_v2_refactored.py
View file @
b6574c92
...
@@ -3881,3 +3881,67 @@ class TestPerspective:
...
@@ -3881,3 +3881,67 @@ class TestPerspective:
)
)
assert_close
(
actual
,
expected
,
rtol
=
0
,
atol
=
1
)
assert_close
(
actual
,
expected
,
rtol
=
0
,
atol
=
1
)
class
TestColorJitter
:
@
pytest
.
mark
.
parametrize
(
"make_input"
,
[
make_image_tensor
,
make_image_pil
,
make_image
,
make_video
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
uint8
,
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_cuda
())
def
test_transform
(
self
,
make_input
,
dtype
,
device
):
if
make_input
is
make_image_pil
and
not
(
dtype
is
torch
.
uint8
and
device
==
"cpu"
):
pytest
.
skip
(
"PIL image tests with parametrization other than dtype=torch.uint8 and device='cpu' "
"will degenerate to that anyway."
)
check_transform
(
transforms
.
ColorJitter
(
brightness
=
0.5
,
contrast
=
0.5
,
saturation
=
0.5
,
hue
=
0.25
),
make_input
(
dtype
=
dtype
,
device
=
device
),
)
def
test_transform_noop
(
self
):
input
=
make_image
()
input_version
=
input
.
_version
transform
=
transforms
.
ColorJitter
()
output
=
transform
(
input
)
assert
output
is
input
assert
output
.
data_ptr
()
==
input
.
data_ptr
()
assert
output
.
_version
==
input_version
def
test_transform_error
(
self
):
with
pytest
.
raises
(
ValueError
,
match
=
"must be non negative"
):
transforms
.
ColorJitter
(
brightness
=-
1
)
for
brightness
in
[
object
(),
[
1
,
2
,
3
]]:
with
pytest
.
raises
(
TypeError
,
match
=
"single number or a sequence with length 2"
):
transforms
.
ColorJitter
(
brightness
=
brightness
)
with
pytest
.
raises
(
ValueError
,
match
=
"values should be between"
):
transforms
.
ColorJitter
(
brightness
=
(
-
1
,
0.5
))
with
pytest
.
raises
(
ValueError
,
match
=
"values should be between"
):
transforms
.
ColorJitter
(
hue
=
1
)
@
pytest
.
mark
.
parametrize
(
"brightness"
,
[
None
,
0.1
,
(
0.2
,
0.3
)])
@
pytest
.
mark
.
parametrize
(
"contrast"
,
[
None
,
0.4
,
(
0.5
,
0.6
)])
@
pytest
.
mark
.
parametrize
(
"saturation"
,
[
None
,
0.7
,
(
0.8
,
0.9
)])
@
pytest
.
mark
.
parametrize
(
"hue"
,
[
None
,
0.3
,
(
-
0.1
,
0.2
)])
def
test_transform_correctness
(
self
,
brightness
,
contrast
,
saturation
,
hue
):
image
=
make_image
(
dtype
=
torch
.
uint8
,
device
=
"cpu"
)
transform
=
transforms
.
ColorJitter
(
brightness
=
brightness
,
contrast
=
contrast
,
saturation
=
saturation
,
hue
=
hue
)
with
freeze_rng_state
():
torch
.
manual_seed
(
0
)
actual
=
transform
(
image
)
torch
.
manual_seed
(
0
)
expected
=
F
.
to_image
(
transform
(
F
.
to_pil_image
(
image
)))
mae
=
(
actual
.
float
()
-
expected
.
float
()).
abs
().
mean
()
assert
mae
<
2
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment