"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "84bbb7140e03df01b3bb388ba4df299328ea2dff"
Unverified Commit 2d6e663a authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

make transforms v2 get_params a staticmethod (#7177)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent bac678c8
...@@ -649,37 +649,58 @@ def test_call_consistency(config, args_kwargs): ...@@ -649,37 +649,58 @@ def test_call_consistency(config, args_kwargs):
) )
@pytest.mark.parametrize( get_params_parametrization = pytest.mark.parametrize(
"config", ("config", "get_params_args_kwargs"),
[config for config in CONSISTENCY_CONFIGS if hasattr(config.legacy_cls, "get_params")], [
ids=lambda config: config.legacy_cls.__name__, pytest.param(
next(config for config in CONSISTENCY_CONFIGS if config.prototype_cls is transform_cls),
get_params_args_kwargs,
id=transform_cls.__name__,
)
for transform_cls, get_params_args_kwargs in [
(prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
(prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
(prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
(prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
(prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
(
prototype_transforms.RandomAffine,
ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]),
),
(prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
(prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
(prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
(prototype_transforms.AutoAugment, ArgsKwargs(5)),
]
],
) )
def test_get_params_alias(config):
@get_paramsl_parametrization
def test_get_params_alias(config, get_params_args_kwargs):
assert config.prototype_cls.get_params is config.legacy_cls.get_params assert config.prototype_cls.get_params is config.legacy_cls.get_params
if not config.args_kwargs:
return
args, kwargs = config.args_kwargs[0]
legacy_transform = config.legacy_cls(*args, **kwargs)
prototype_transform = config.prototype_cls(*args, **kwargs)
@pytest.mark.parametrize( assert prototype_transform.get_params is legacy_transform.get_params
("transform_cls", "args_kwargs"),
[
(prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])), @get_paramsl_parametrization
(prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))), def test_get_params_jit(config, get_params_args_kwargs):
(prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)), get_params_args, get_params_kwargs = get_params_args_kwargs
(prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
(prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)), torch.jit.script(config.prototype_cls.get_params)(*get_params_args, **get_params_kwargs)
(
prototype_transforms.RandomAffine, if not config.args_kwargs:
ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]), return
), args, kwargs = config.args_kwargs[0]
(prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))), transform = config.prototype_cls(*args, **kwargs)
(prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
(prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
(prototype_transforms.AutoAugment, ArgsKwargs(5)),
],
)
def test_get_params_jit(transform_cls, args_kwargs):
args, kwargs = args_kwargs
torch.jit.script(transform_cls.get_params)(*args, **kwargs) torch.jit.script(transform.get_params)(*get_params_args, **get_params_kwargs)
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -67,7 +67,7 @@ class Transform(nn.Module): ...@@ -67,7 +67,7 @@ class Transform(nn.Module):
# Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance. # Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance.
# This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`. # This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`.
if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"): if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
cls.get_params = cls._v1_transform_cls.get_params # type: ignore[attr-defined] cls.get_params = staticmethod(cls._v1_transform_cls.get_params) # type: ignore[attr-defined]
def _extract_params_for_v1_transform(self) -> Dict[str, Any]: def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment