Unverified Commit a46d97c9 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

align transforms v2 signatures with v1 (#7301)


Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
parent 49c6961a
...@@ -540,9 +540,12 @@ def test_signature_consistency(config): ...@@ -540,9 +540,12 @@ def test_signature_consistency(config):
f"not. Please add a default value." f"not. Please add a default value."
) )
legacy_kinds = {name: param.kind for name, param in legacy_params.items()} legacy_signature = list(legacy_params.keys())
prototype_kinds = {name: prototype_params[name].kind for name in legacy_kinds.keys()} # Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
assert prototype_kinds == legacy_kinds # to the same number of parameters as the legacy one
prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]
assert prototype_signature == legacy_signature
def check_call_consistency( def check_call_consistency(
......
...@@ -124,8 +124,8 @@ class RandomChoice(Transform): ...@@ -124,8 +124,8 @@ class RandomChoice(Transform):
def __init__( def __init__(
self, self,
transforms: Sequence[Callable], transforms: Sequence[Callable],
probabilities: Optional[List[float]] = None,
p: Optional[List[float]] = None, p: Optional[List[float]] = None,
probabilities: Optional[List[float]] = None,
) -> None: ) -> None:
if not isinstance(transforms, Sequence): if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables") raise TypeError("Argument transforms should be a sequence of callables")
......
...@@ -575,8 +575,8 @@ class RandomRotation(Transform): ...@@ -575,8 +575,8 @@ class RandomRotation(Transform):
degrees: Union[numbers.Number, Sequence], degrees: Union[numbers.Number, Sequence],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
) -> None: ) -> None:
super().__init__() super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
...@@ -903,9 +903,9 @@ class RandomPerspective(_RandomApplyTransform): ...@@ -903,9 +903,9 @@ class RandomPerspective(_RandomApplyTransform):
def __init__( def __init__(
self, self,
distortion_scale: float = 0.5, distortion_scale: float = 0.5,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
p: float = 0.5, p: float = 0.5,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
) -> None: ) -> None:
super().__init__(p=p) super().__init__(p=p)
...@@ -966,8 +966,8 @@ class ElasticTransform(Transform): ...@@ -966,8 +966,8 @@ class ElasticTransform(Transform):
self, self,
alpha: Union[float, Sequence[float]] = 50.0, alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0, sigma: Union[float, Sequence[float]] = 5.0,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
) -> None: ) -> None:
super().__init__() super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2) self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment