Unverified Commit f30a92f1 authored by Philip Meier, committed by GitHub

add a signature consistency test for v1 vs. v2 dispatchers (#6914)

* add a signature consistency test for v1 vs. v2 dispatchers

* temporarily increase test verbosity

* Revert "temporarily increase test verbosity"

This reverts commit 468c73f727af8dde8a2a3cc3063fd64923c50f63.

* fix test to allow annotation deviations

* fill <-> center for rotate

* ignore annotation changes for center / translate in rotate / affine
parent b80f83db
@@ -25,9 +25,10 @@ from prototype_common_utils import (
 from torchvision import transforms as legacy_transforms
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import features, transforms as prototype_transforms
-from torchvision.prototype.transforms import functional as F
+from torchvision.prototype.transforms import functional as prototype_F
 from torchvision.prototype.transforms._utils import query_spatial_size
 from torchvision.prototype.transforms.functional import to_image_pil
+from torchvision.transforms import functional as legacy_F


 DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
@@ -985,7 +986,7 @@ class PadIfSmaller(prototype_transforms.Transform):
             return inpt

         fill = self.fill[type(inpt)]
-        return F.pad(inpt, padding=params["padding"], fill=fill)
+        return prototype_F.pad(inpt, padding=params["padding"], fill=fill)


 class TestRefSegTransforms:
@@ -1119,3 +1120,81 @@ class TestRefSegTransforms:
         t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)

         self.check_resize(mocker, t_ref, t)
+
+
+@pytest.mark.parametrize(
+    ("legacy_dispatcher", "name_only_params"),
+    [
+        (legacy_F.get_dimensions, {}),
+        (legacy_F.get_image_size, {}),
+        (legacy_F.get_image_num_channels, {}),
+        (legacy_F.to_tensor, {}),
+        (legacy_F.pil_to_tensor, {}),
+        (legacy_F.convert_image_dtype, {}),
+        (legacy_F.to_pil_image, {}),
+        (legacy_F.normalize, {}),
+        (legacy_F.resize, {}),
+        (legacy_F.pad, {"padding", "fill"}),
+        (legacy_F.crop, {}),
+        (legacy_F.center_crop, {}),
+        (legacy_F.resized_crop, {}),
+        (legacy_F.hflip, {}),
+        (legacy_F.perspective, {"startpoints", "endpoints", "fill"}),
+        (legacy_F.vflip, {}),
+        (legacy_F.five_crop, {}),
+        (legacy_F.ten_crop, {}),
+        (legacy_F.adjust_brightness, {}),
+        (legacy_F.adjust_contrast, {}),
+        (legacy_F.adjust_saturation, {}),
+        (legacy_F.adjust_hue, {}),
+        (legacy_F.adjust_gamma, {}),
+        (legacy_F.rotate, {"center", "fill"}),
+        (legacy_F.affine, {"angle", "translate", "center", "fill"}),
+        (legacy_F.to_grayscale, {}),
+        (legacy_F.rgb_to_grayscale, {}),
+        (legacy_F.to_tensor, {}),
+        (legacy_F.erase, {}),
+        (legacy_F.gaussian_blur, {}),
+        (legacy_F.invert, {}),
+        (legacy_F.posterize, {}),
+        (legacy_F.solarize, {}),
+        (legacy_F.adjust_sharpness, {}),
+        (legacy_F.autocontrast, {}),
+        (legacy_F.equalize, {}),
+        (legacy_F.elastic_transform, {"fill"}),
+    ],
+)
+def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
+    legacy_signature = inspect.signature(legacy_dispatcher)
+    legacy_params = list(legacy_signature.parameters.values())[1:]
+
+    try:
+        prototype_dispatcher = getattr(prototype_F, legacy_dispatcher.__name__)
+    except AttributeError:
+        raise AssertionError(
+            f"Legacy dispatcher `F.{legacy_dispatcher.__name__}` has no prototype equivalent"
+        ) from None
+
+    prototype_signature = inspect.signature(prototype_dispatcher)
+    prototype_params = list(prototype_signature.parameters.values())[1:]
+
+    # Some dispatchers got extra parameters. This makes sure they have a default argument and thus are BC. We don't
+    # need to check if parameters were added in the middle rather than at the end, since that will be caught by the
+    # regular check below.
+    prototype_params, new_prototype_params = (
+        prototype_params[: len(legacy_params)],
+        prototype_params[len(legacy_params) :],
+    )
+    for param in new_prototype_params:
+        assert param.default is not param.empty
+
+    # Some annotations were changed mostly to supersets of what was there before. Plus, some legacy dispatchers had no
+    # annotations. In these cases we simply drop the annotation and default argument from the comparison
+    for prototype_param, legacy_param in zip(prototype_params, legacy_params):
+        if legacy_param.name in name_only_params:
+            prototype_param._annotation = prototype_param._default = inspect.Parameter.empty
+            legacy_param._annotation = legacy_param._default = inspect.Parameter.empty
+        elif legacy_param.annotation is inspect.Parameter.empty:
+            prototype_param._annotation = inspect.Parameter.empty
+
+    assert prototype_params == legacy_params
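For reference, the final equality check in the new test compares `inspect.Parameter` objects, which match on name, kind, default, and annotation. A minimal standalone sketch of the same mechanism, using two hypothetical dispatchers that are not part of this PR:

import inspect


def rotate_v1(img, angle, interpolation="nearest", expand=False, center=None, fill=None):
    ...


def rotate_v2(inpt, angle, interpolation="nearest", expand=False, center=None, fill=None, antialias=True):
    ...


# Skip the first parameter (img vs. inpt), as the test above does.
v1_params = list(inspect.signature(rotate_v1).parameters.values())[1:]
v2_params = list(inspect.signature(rotate_v2).parameters.values())[1:]

# Extra trailing parameters are acceptable as long as they have a default ...
extra_params = v2_params[len(v1_params):]
assert all(p.default is not p.empty for p in extra_params)

# ... and the shared prefix must match exactly (name, kind, default, annotation).
assert v2_params[: len(v1_params)] == v1_params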
@@ -132,8 +132,8 @@ class BoundingBox(_Feature):
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
+        fill: FillTypeJIT = None,
     ) -> BoundingBox:
         output, spatial_size = self._F.rotate_bounding_box(
             self.as_subclass(torch.Tensor),
@@ -199,8 +199,8 @@ class _Feature(torch.Tensor):
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
+        fill: FillTypeJIT = None,
     ) -> _Feature:
         return self
@@ -174,8 +174,8 @@ class Image(_Feature):
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
+        fill: FillTypeJIT = None,
     ) -> Image:
         output = self._F.rotate_image_tensor(
             self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
@@ -89,8 +89,8 @@ class Mask(_Feature):
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
+        fill: FillTypeJIT = None,
     ) -> Mask:
         output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill)
         return Mask.wrap_like(self, output)
@@ -134,8 +134,8 @@ class Video(_Feature):
         angle: float,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False,
-        fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
+        fill: FillTypeJIT = None,
     ) -> Video:
         output = self._F.rotate_video(
             self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
@@ -305,8 +305,8 @@ class RandomRotation(Transform):
             **params,
             interpolation=self.interpolation,
             expand=self.expand,
-            fill=fill,
             center=self.center,
+            fill=fill,
         )
@@ -521,8 +521,8 @@ def rotate_image_tensor(
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
+    fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
     shape = image.shape
     num_channels, height, width = shape[-3:]
@@ -560,8 +560,8 @@ def rotate_image_pil(
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
+    fill: features.FillTypeJIT = None,
 ) -> PIL.Image.Image:
     if center is not None and expand:
         warnings.warn("The provided center argument has no effect on the result if expand is True")
@@ -612,8 +612,8 @@ def rotate_mask(
     mask: torch.Tensor,
     angle: float,
     expand: bool = False,
-    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
+    fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
     if mask.ndim < 3:
         mask = mask.unsqueeze(0)
@@ -641,8 +641,8 @@ def rotate_video(
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
+    fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
     return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
@@ -652,8 +652,8 @@ def rotate(
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
+    fill: features.FillTypeJIT = None,
 ) -> features.InputTypeJIT:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)