Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
82c51c48
Unverified
Commit
82c51c48
authored
Feb 01, 2023
by
Philip Meier
Committed by
GitHub
Feb 01, 2023
Browse files
enable get_params alias for transforms v2 (#7153)
parent
6bd04f65
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
2 deletions
+44
-2
test/test_prototype_transforms_consistency.py
test/test_prototype_transforms_consistency.py
+33
-0
torchvision/prototype/transforms/_transform.py
torchvision/prototype/transforms/_transform.py
+11
-2
No files found.
test/test_prototype_transforms_consistency.py
View file @
82c51c48
...
@@ -655,6 +655,39 @@ def test_call_consistency(config, args_kwargs):
...
@@ -655,6 +655,39 @@ def test_call_consistency(config, args_kwargs):
)
)
@
pytest
.
mark
.
parametrize
(
"config"
,
[
config
for
config
in
CONSISTENCY_CONFIGS
if
hasattr
(
config
.
legacy_cls
,
"get_params"
)],
ids
=
lambda
config
:
config
.
legacy_cls
.
__name__
,
)
def
test_get_params_alias
(
config
):
assert
config
.
prototype_cls
.
get_params
is
config
.
legacy_cls
.
get_params
@
pytest
.
mark
.
parametrize
(
(
"transform_cls"
,
"args_kwargs"
),
[
(
prototype_transforms
.
RandomResizedCrop
,
ArgsKwargs
(
make_image
(),
scale
=
[
0.3
,
0.7
],
ratio
=
[
0.5
,
1.5
])),
(
prototype_transforms
.
RandomErasing
,
ArgsKwargs
(
make_image
(),
scale
=
(
0.3
,
0.7
),
ratio
=
(
0.5
,
1.5
))),
(
prototype_transforms
.
ColorJitter
,
ArgsKwargs
(
brightness
=
None
,
contrast
=
None
,
saturation
=
None
,
hue
=
None
)),
(
prototype_transforms
.
ElasticTransform
,
ArgsKwargs
(
alpha
=
[
15.3
,
27.2
],
sigma
=
[
2.5
,
3.9
],
size
=
[
17
,
31
])),
(
prototype_transforms
.
GaussianBlur
,
ArgsKwargs
(
0.3
,
1.4
)),
(
prototype_transforms
.
RandomAffine
,
ArgsKwargs
(
degrees
=
[
-
20.0
,
10.0
],
translate
=
None
,
scale_ranges
=
None
,
shears
=
None
,
img_size
=
[
15
,
29
]),
),
(
prototype_transforms
.
RandomCrop
,
ArgsKwargs
(
make_image
(
size
=
(
61
,
47
)),
output_size
=
(
19
,
25
))),
(
prototype_transforms
.
RandomPerspective
,
ArgsKwargs
(
23
,
17
,
0.5
)),
(
prototype_transforms
.
RandomRotation
,
ArgsKwargs
(
degrees
=
[
-
20.0
,
10.0
])),
(
prototype_transforms
.
AutoAugment
,
ArgsKwargs
(
5
)),
],
)
def
test_get_params_jit
(
transform_cls
,
args_kwargs
):
args
,
kwargs
=
args_kwargs
torch
.
jit
.
script
(
transform_cls
.
get_params
)(
*
args
,
**
kwargs
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"config"
,
"args_kwargs"
),
(
"config"
,
"args_kwargs"
),
[
[
...
...
torchvision/prototype/transforms/_transform.py
View file @
82c51c48
...
@@ -56,10 +56,19 @@ class Transform(nn.Module):
...
@@ -56,10 +56,19 @@ class Transform(nn.Module):
return
", "
.
join
(
extra
)
return
", "
.
join
(
extra
)
# This attribute should be set on all transforms that have a v1 equivalent. Doing so enables the v2 transformation
# This attribute should be set on all transforms that have a v1 equivalent. Doing so enables two things:
# to be scriptable. See `_extract_params_for_v1_transform()` and `__prepare_scriptable__` for details.
# 1. In case the v1 transform has a static `get_params` method, it will also be available under the same name on
# the v2 transform. See `__init_subclass__` for details.
# 2. The v2 transform will be JIT scriptable. See `_extract_params_for_v1_transform` and `__prepare_scriptable__`
# for details.
_v1_transform_cls
:
Optional
[
Type
[
nn
.
Module
]]
=
None
_v1_transform_cls
:
Optional
[
Type
[
nn
.
Module
]]
=
None
def
__init_subclass__
(
cls
)
->
None
:
# Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance.
# This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`.
if
cls
.
_v1_transform_cls
is
not
None
and
hasattr
(
cls
.
_v1_transform_cls
,
"get_params"
):
cls
.
get_params
=
cls
.
_v1_transform_cls
.
get_params
# type: ignore[attr-defined]
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
def
_extract_params_for_v1_transform
(
self
)
->
Dict
[
str
,
Any
]:
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
# v2 transform instance. It does two things:
# v2 transform instance. It does two things:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment