Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
2d6e663a
Unverified
Commit
2d6e663a
authored
Feb 07, 2023
by
Philip Meier
Committed by
GitHub
Feb 07, 2023
Browse files
make transforms v2 get_params a staticmethod (#7177)
Co-authored-by:
Nicolas Hug
<
contact@nicolas-hug.com
>
parent
bac678c8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
27 deletions
+48
-27
test/test_prototype_transforms_consistency.py
test/test_prototype_transforms_consistency.py
+47
-26
torchvision/prototype/transforms/_transform.py
torchvision/prototype/transforms/_transform.py
+1
-1
No files found.
test/test_prototype_transforms_consistency.py
View file @
2d6e663a
...
...
@@ -649,37 +649,58 @@ def test_call_consistency(config, args_kwargs):
)
@
pytest
.
mark
.
parametrize
(
"config"
,
[
config
for
config
in
CONSISTENCY_CONFIGS
if
hasattr
(
config
.
legacy_cls
,
"get_params"
)],
ids
=
lambda
config
:
config
.
legacy_cls
.
__name__
,
get_params_parametrization
=
pytest
.
mark
.
parametrize
(
(
"config"
,
"get_params_args_kwargs"
),
[
pytest
.
param
(
next
(
config
for
config
in
CONSISTENCY_CONFIGS
if
config
.
prototype_cls
is
transform_cls
),
get_params_args_kwargs
,
id
=
transform_cls
.
__name__
,
)
for
transform_cls
,
get_params_args_kwargs
in
[
(
prototype_transforms
.
RandomResizedCrop
,
ArgsKwargs
(
make_image
(),
scale
=
[
0.3
,
0.7
],
ratio
=
[
0.5
,
1.5
])),
(
prototype_transforms
.
RandomErasing
,
ArgsKwargs
(
make_image
(),
scale
=
(
0.3
,
0.7
),
ratio
=
(
0.5
,
1.5
))),
(
prototype_transforms
.
ColorJitter
,
ArgsKwargs
(
brightness
=
None
,
contrast
=
None
,
saturation
=
None
,
hue
=
None
)),
(
prototype_transforms
.
ElasticTransform
,
ArgsKwargs
(
alpha
=
[
15.3
,
27.2
],
sigma
=
[
2.5
,
3.9
],
size
=
[
17
,
31
])),
(
prototype_transforms
.
GaussianBlur
,
ArgsKwargs
(
0.3
,
1.4
)),
(
prototype_transforms
.
RandomAffine
,
ArgsKwargs
(
degrees
=
[
-
20.0
,
10.0
],
translate
=
None
,
scale_ranges
=
None
,
shears
=
None
,
img_size
=
[
15
,
29
]),
),
(
prototype_transforms
.
RandomCrop
,
ArgsKwargs
(
make_image
(
size
=
(
61
,
47
)),
output_size
=
(
19
,
25
))),
(
prototype_transforms
.
RandomPerspective
,
ArgsKwargs
(
23
,
17
,
0.5
)),
(
prototype_transforms
.
RandomRotation
,
ArgsKwargs
(
degrees
=
[
-
20.0
,
10.0
])),
(
prototype_transforms
.
AutoAugment
,
ArgsKwargs
(
5
)),
]
],
)
def
test_get_params_alias
(
config
):
@
get_paramsl_parametrization
def
test_get_params_alias
(
config
,
get_params_args_kwargs
):
assert
config
.
prototype_cls
.
get_params
is
config
.
legacy_cls
.
get_params
if
not
config
.
args_kwargs
:
return
args
,
kwargs
=
config
.
args_kwargs
[
0
]
legacy_transform
=
config
.
legacy_cls
(
*
args
,
**
kwargs
)
prototype_transform
=
config
.
prototype_cls
(
*
args
,
**
kwargs
)
@
pytest
.
mark
.
parametrize
(
(
"transform_cls"
,
"args_kwargs"
),
[
(
prototype_transforms
.
RandomResizedCrop
,
ArgsKwargs
(
make_image
(),
scale
=
[
0.3
,
0.7
],
ratio
=
[
0.5
,
1.5
])),
(
prototype_transforms
.
RandomErasing
,
ArgsKwargs
(
make_image
(),
scale
=
(
0.3
,
0.7
),
ratio
=
(
0.5
,
1.5
))),
(
prototype_transforms
.
ColorJitter
,
ArgsKwargs
(
brightness
=
None
,
contrast
=
None
,
saturation
=
None
,
hue
=
None
)),
(
prototype_transforms
.
ElasticTransform
,
ArgsKwargs
(
alpha
=
[
15.3
,
27.2
],
sigma
=
[
2.5
,
3.9
],
size
=
[
17
,
31
])),
(
prototype_transforms
.
GaussianBlur
,
ArgsKwargs
(
0.3
,
1.4
)),
(
prototype_transforms
.
RandomAffine
,
ArgsKwargs
(
degrees
=
[
-
20.0
,
10.0
],
translate
=
None
,
scale_ranges
=
None
,
shears
=
None
,
img_size
=
[
15
,
29
]),
),
(
prototype_transforms
.
RandomCrop
,
ArgsKwargs
(
make_image
(
size
=
(
61
,
47
)),
output_size
=
(
19
,
25
))),
(
prototype_transforms
.
RandomPerspective
,
ArgsKwargs
(
23
,
17
,
0.5
)),
(
prototype_transforms
.
RandomRotation
,
ArgsKwargs
(
degrees
=
[
-
20.0
,
10.0
])),
(
prototype_transforms
.
AutoAugment
,
ArgsKwargs
(
5
)),
],
)
def
test_get_params_jit
(
transform_cls
,
args_kwargs
):
args
,
kwargs
=
args_kwargs
assert
prototype_transform
.
get_params
is
legacy_transform
.
get_params
@
get_paramsl_parametrization
def
test_get_params_jit
(
config
,
get_params_args_kwargs
):
get_params_args
,
get_params_kwargs
=
get_params_args_kwargs
torch
.
jit
.
script
(
config
.
prototype_cls
.
get_params
)(
*
get_params_args
,
**
get_params_kwargs
)
if
not
config
.
args_kwargs
:
return
args
,
kwargs
=
config
.
args_kwargs
[
0
]
transform
=
config
.
prototype_cls
(
*
args
,
**
kwargs
)
torch
.
jit
.
script
(
transform
_cls
.
get_params
)(
*
args
,
**
kwargs
)
torch
.
jit
.
script
(
transform
.
get_params
)(
*
get_params_args
,
**
get_params_
kwargs
)
@
pytest
.
mark
.
parametrize
(
...
...
torchvision/prototype/transforms/_transform.py
View file @
2d6e663a
...
...
@@ -67,7 +67,7 @@ class Transform(nn.Module):
# Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance.
# This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`.
if
cls
.
_v1_transform_cls
is
not
None
and
hasattr
(
cls
.
_v1_transform_cls
,
"get_params"
):
cls
.
get_params
=
cls
.
_v1_transform_cls
.
get_params
# type: ignore[attr-defined]
cls
.
get_params
=
staticmethod
(
cls
.
_v1_transform_cls
.
get_params
)
# type: ignore[attr-defined]
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment