Commit 0e0a5dc7 (unverified), authored by Philip Meier, committed by GitHub
Browse files

Support integer values for interpolation in the prototype transforms (#7248)

parent f627b9d1
...@@ -1534,7 +1534,7 @@ class TestScaleJitter: ...@@ -1534,7 +1534,7 @@ class TestScaleJitter:
assert int(spatial_size[1] * r_min) <= width <= int(spatial_size[1] * r_max) assert int(spatial_size[1] * r_min) <= width <= int(spatial_size[1] * r_max)
def test__transform(self, mocker): def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock() interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock() antialias_sentinel = mocker.MagicMock()
transform = transforms.ScaleJitter( transform = transforms.ScaleJitter(
...@@ -1581,7 +1581,7 @@ class TestRandomShortestSize: ...@@ -1581,7 +1581,7 @@ class TestRandomShortestSize:
assert shorter in min_size assert shorter in min_size
def test__transform(self, mocker): def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock() interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock() antialias_sentinel = mocker.MagicMock()
transform = transforms.RandomShortestSize( transform = transforms.RandomShortestSize(
...@@ -1945,7 +1945,7 @@ class TestRandomResize: ...@@ -1945,7 +1945,7 @@ class TestRandomResize:
assert min_size <= size < max_size assert min_size <= size < max_size
def test__transform(self, mocker): def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock() interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock() antialias_sentinel = mocker.MagicMock()
transform = transforms.RandomResize( transform = transforms.RandomResize(
......
...@@ -88,6 +88,9 @@ CONSISTENCY_CONFIGS = [ ...@@ -88,6 +88,9 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs((32, 29)), ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST), ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC), ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC),
NotScriptableArgsKwargs(31, max_size=32), NotScriptableArgsKwargs(31, max_size=32),
ArgsKwargs([31], max_size=32), ArgsKwargs([31], max_size=32),
NotScriptableArgsKwargs(30, max_size=100), NotScriptableArgsKwargs(30, max_size=100),
...@@ -305,6 +308,8 @@ CONSISTENCY_CONFIGS = [ ...@@ -305,6 +308,8 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(25, ratio=(0.5, 1.5)), ArgsKwargs(25, ratio=(0.5, 1.5)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST), ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC), ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC),
ArgsKwargs((29, 32), antialias=False), ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True), ArgsKwargs((28, 31), antialias=True),
], ],
...@@ -352,6 +357,8 @@ CONSISTENCY_CONFIGS = [ ...@@ -352,6 +357,8 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(sigma=(2.5, 3.9)), ArgsKwargs(sigma=(2.5, 3.9)),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST), ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC), ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs(interpolation=PIL.Image.NEAREST),
ArgsKwargs(interpolation=PIL.Image.BICUBIC),
ArgsKwargs(fill=1), ArgsKwargs(fill=1),
], ],
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
...@@ -386,6 +393,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -386,6 +393,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)), ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)), ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.NEAREST), ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST),
ArgsKwargs(degrees=30.0, fill=1), ArgsKwargs(degrees=30.0, fill=1),
ArgsKwargs(degrees=30.0, fill=(2, 3, 4)), ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
ArgsKwargs(degrees=30.0, center=(0, 0)), ArgsKwargs(degrees=30.0, center=(0, 0)),
...@@ -420,6 +428,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -420,6 +428,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(p=1), ArgsKwargs(p=1),
ArgsKwargs(p=1, distortion_scale=0.3), ArgsKwargs(p=1, distortion_scale=0.3),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=prototype_transforms.InterpolationMode.NEAREST), ArgsKwargs(p=1, distortion_scale=0.2, interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.1, fill=1), ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)), ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
], ],
...@@ -432,6 +441,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -432,6 +441,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(degrees=30.0), ArgsKwargs(degrees=30.0),
ArgsKwargs(degrees=(-20.0, 10.0)), ArgsKwargs(degrees=(-20.0, 10.0)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.BILINEAR), ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.BILINEAR),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR),
ArgsKwargs(degrees=30.0, expand=True), ArgsKwargs(degrees=30.0, expand=True),
ArgsKwargs(degrees=30.0, center=(0, 0)), ArgsKwargs(degrees=30.0, center=(0, 0)),
ArgsKwargs(degrees=30.0, fill=1), ArgsKwargs(degrees=30.0, fill=1),
...@@ -851,7 +861,11 @@ class TestAATransforms: ...@@ -851,7 +861,11 @@ class TestAATransforms:
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"interpolation", "interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR], [
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
) )
def test_randaug(self, inpt, interpolation, mocker): def test_randaug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1) t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
...@@ -889,7 +903,11 @@ class TestAATransforms: ...@@ -889,7 +903,11 @@ class TestAATransforms:
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"interpolation", "interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR], [
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
) )
def test_trivial_aug(self, inpt, interpolation, mocker): def test_trivial_aug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation) t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
...@@ -937,7 +955,11 @@ class TestAATransforms: ...@@ -937,7 +955,11 @@ class TestAATransforms:
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"interpolation", "interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR], [
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
) )
def test_augmix(self, inpt, interpolation, mocker): def test_augmix(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1) t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
...@@ -986,7 +1008,11 @@ class TestAATransforms: ...@@ -986,7 +1008,11 @@ class TestAATransforms:
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"interpolation", "interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR], [
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
) )
def test_aa(self, inpt, interpolation): def test_aa(self, inpt, interpolation):
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet") aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
...@@ -1264,13 +1290,13 @@ class TestRefSegTransforms: ...@@ -1264,13 +1290,13 @@ class TestRefSegTransforms:
(legacy_F.convert_image_dtype, {}), (legacy_F.convert_image_dtype, {}),
(legacy_F.to_pil_image, {}), (legacy_F.to_pil_image, {}),
(legacy_F.normalize, {}), (legacy_F.normalize, {}),
(legacy_F.resize, {}), (legacy_F.resize, {"interpolation"}),
(legacy_F.pad, {"padding", "fill"}), (legacy_F.pad, {"padding", "fill"}),
(legacy_F.crop, {}), (legacy_F.crop, {}),
(legacy_F.center_crop, {}), (legacy_F.center_crop, {}),
(legacy_F.resized_crop, {}), (legacy_F.resized_crop, {"interpolation"}),
(legacy_F.hflip, {}), (legacy_F.hflip, {}),
(legacy_F.perspective, {"startpoints", "endpoints", "fill"}), (legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
(legacy_F.vflip, {}), (legacy_F.vflip, {}),
(legacy_F.five_crop, {}), (legacy_F.five_crop, {}),
(legacy_F.ten_crop, {}), (legacy_F.ten_crop, {}),
...@@ -1279,8 +1305,8 @@ class TestRefSegTransforms: ...@@ -1279,8 +1305,8 @@ class TestRefSegTransforms:
(legacy_F.adjust_saturation, {}), (legacy_F.adjust_saturation, {}),
(legacy_F.adjust_hue, {}), (legacy_F.adjust_hue, {}),
(legacy_F.adjust_gamma, {}), (legacy_F.adjust_gamma, {}),
(legacy_F.rotate, {"center", "fill"}), (legacy_F.rotate, {"center", "fill", "interpolation"}),
(legacy_F.affine, {"angle", "translate", "center", "fill"}), (legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
(legacy_F.to_grayscale, {}), (legacy_F.to_grayscale, {}),
(legacy_F.rgb_to_grayscale, {}), (legacy_F.rgb_to_grayscale, {}),
(legacy_F.to_tensor, {}), (legacy_F.to_tensor, {}),
...@@ -1292,7 +1318,7 @@ class TestRefSegTransforms: ...@@ -1292,7 +1318,7 @@ class TestRefSegTransforms:
(legacy_F.adjust_sharpness, {}), (legacy_F.adjust_sharpness, {}),
(legacy_F.autocontrast, {}), (legacy_F.autocontrast, {}),
(legacy_F.equalize, {}), (legacy_F.equalize, {}),
(legacy_F.elastic_transform, {"fill"}), (legacy_F.elastic_transform, {"fill", "interpolation"}),
], ],
) )
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params): def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
......
...@@ -76,7 +76,7 @@ class BoundingBox(Datapoint): ...@@ -76,7 +76,7 @@ class BoundingBox(Datapoint):
def resize( # type: ignore[override] def resize( # type: ignore[override]
self, self,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBox: ) -> BoundingBox:
...@@ -107,7 +107,7 @@ class BoundingBox(Datapoint): ...@@ -107,7 +107,7 @@ class BoundingBox(Datapoint):
height: int, height: int,
width: int, width: int,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBox: ) -> BoundingBox:
output, spatial_size = self._F.resized_crop_bounding_box( output, spatial_size = self._F.resized_crop_bounding_box(
...@@ -133,7 +133,7 @@ class BoundingBox(Datapoint): ...@@ -133,7 +133,7 @@ class BoundingBox(Datapoint):
def rotate( def rotate(
self, self,
angle: float, angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
...@@ -154,7 +154,7 @@ class BoundingBox(Datapoint): ...@@ -154,7 +154,7 @@ class BoundingBox(Datapoint):
translate: List[float], translate: List[float],
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> BoundingBox: ) -> BoundingBox:
...@@ -174,7 +174,7 @@ class BoundingBox(Datapoint): ...@@ -174,7 +174,7 @@ class BoundingBox(Datapoint):
self, self,
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> BoundingBox: ) -> BoundingBox:
...@@ -191,7 +191,7 @@ class BoundingBox(Datapoint): ...@@ -191,7 +191,7 @@ class BoundingBox(Datapoint):
def elastic( def elastic(
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
) -> BoundingBox: ) -> BoundingBox:
output = self._F.elastic_bounding_box( output = self._F.elastic_bounding_box(
......
...@@ -143,7 +143,7 @@ class Datapoint(torch.Tensor): ...@@ -143,7 +143,7 @@ class Datapoint(torch.Tensor):
def resize( # type: ignore[override] def resize( # type: ignore[override]
self, self,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> Datapoint: ) -> Datapoint:
...@@ -162,7 +162,7 @@ class Datapoint(torch.Tensor): ...@@ -162,7 +162,7 @@ class Datapoint(torch.Tensor):
height: int, height: int,
width: int, width: int,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> Datapoint: ) -> Datapoint:
return self return self
...@@ -178,7 +178,7 @@ class Datapoint(torch.Tensor): ...@@ -178,7 +178,7 @@ class Datapoint(torch.Tensor):
def rotate( def rotate(
self, self,
angle: float, angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
...@@ -191,7 +191,7 @@ class Datapoint(torch.Tensor): ...@@ -191,7 +191,7 @@ class Datapoint(torch.Tensor):
translate: List[float], translate: List[float],
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> Datapoint: ) -> Datapoint:
...@@ -201,7 +201,7 @@ class Datapoint(torch.Tensor): ...@@ -201,7 +201,7 @@ class Datapoint(torch.Tensor):
self, self,
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> Datapoint: ) -> Datapoint:
...@@ -210,7 +210,7 @@ class Datapoint(torch.Tensor): ...@@ -210,7 +210,7 @@ class Datapoint(torch.Tensor):
def elastic( def elastic(
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
) -> Datapoint: ) -> Datapoint:
return self return self
......
...@@ -62,7 +62,7 @@ class Image(Datapoint): ...@@ -62,7 +62,7 @@ class Image(Datapoint):
def resize( # type: ignore[override] def resize( # type: ignore[override]
self, self,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> Image: ) -> Image:
...@@ -86,7 +86,7 @@ class Image(Datapoint): ...@@ -86,7 +86,7 @@ class Image(Datapoint):
height: int, height: int,
width: int, width: int,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> Image: ) -> Image:
output = self._F.resized_crop_image_tensor( output = self._F.resized_crop_image_tensor(
...@@ -113,7 +113,7 @@ class Image(Datapoint): ...@@ -113,7 +113,7 @@ class Image(Datapoint):
def rotate( def rotate(
self, self,
angle: float, angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
...@@ -129,7 +129,7 @@ class Image(Datapoint): ...@@ -129,7 +129,7 @@ class Image(Datapoint):
translate: List[float], translate: List[float],
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> Image: ) -> Image:
...@@ -149,7 +149,7 @@ class Image(Datapoint): ...@@ -149,7 +149,7 @@ class Image(Datapoint):
self, self,
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> Image: ) -> Image:
...@@ -166,7 +166,7 @@ class Image(Datapoint): ...@@ -166,7 +166,7 @@ class Image(Datapoint):
def elastic( def elastic(
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
) -> Image: ) -> Image:
output = self._F.elastic_image_tensor( output = self._F.elastic_image_tensor(
......
...@@ -53,7 +53,7 @@ class Mask(Datapoint): ...@@ -53,7 +53,7 @@ class Mask(Datapoint):
def resize( # type: ignore[override] def resize( # type: ignore[override]
self, self,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> Mask: ) -> Mask:
...@@ -75,7 +75,7 @@ class Mask(Datapoint): ...@@ -75,7 +75,7 @@ class Mask(Datapoint):
height: int, height: int,
width: int, width: int,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> Mask: ) -> Mask:
output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size) output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
...@@ -93,7 +93,7 @@ class Mask(Datapoint): ...@@ -93,7 +93,7 @@ class Mask(Datapoint):
def rotate( def rotate(
self, self,
angle: float, angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
...@@ -107,7 +107,7 @@ class Mask(Datapoint): ...@@ -107,7 +107,7 @@ class Mask(Datapoint):
translate: List[float], translate: List[float],
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> Mask: ) -> Mask:
...@@ -126,7 +126,7 @@ class Mask(Datapoint): ...@@ -126,7 +126,7 @@ class Mask(Datapoint):
self, self,
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> Mask: ) -> Mask:
...@@ -138,7 +138,7 @@ class Mask(Datapoint): ...@@ -138,7 +138,7 @@ class Mask(Datapoint):
def elastic( def elastic(
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
) -> Mask: ) -> Mask:
output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill) output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
......
...@@ -57,7 +57,7 @@ class Video(Datapoint): ...@@ -57,7 +57,7 @@ class Video(Datapoint):
def resize( # type: ignore[override] def resize( # type: ignore[override]
self, self,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> Video: ) -> Video:
...@@ -85,7 +85,7 @@ class Video(Datapoint): ...@@ -85,7 +85,7 @@ class Video(Datapoint):
height: int, height: int,
width: int, width: int,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> Video: ) -> Video:
output = self._F.resized_crop_video( output = self._F.resized_crop_video(
...@@ -112,7 +112,7 @@ class Video(Datapoint): ...@@ -112,7 +112,7 @@ class Video(Datapoint):
def rotate( def rotate(
self, self,
angle: float, angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
...@@ -128,7 +128,7 @@ class Video(Datapoint): ...@@ -128,7 +128,7 @@ class Video(Datapoint):
translate: List[float], translate: List[float],
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> Video: ) -> Video:
...@@ -148,7 +148,7 @@ class Video(Datapoint): ...@@ -148,7 +148,7 @@ class Video(Datapoint):
self, self,
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> Video: ) -> Video:
...@@ -165,7 +165,7 @@ class Video(Datapoint): ...@@ -165,7 +165,7 @@ class Video(Datapoint):
def elastic( def elastic(
self, self,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
) -> Video: ) -> Video:
output = self._F.elastic_video( output = self._F.elastic_video(
......
...@@ -10,6 +10,7 @@ from torchvision import transforms as _transforms ...@@ -10,6 +10,7 @@ from torchvision import transforms as _transforms
from torchvision.ops import masks_to_boxes from torchvision.ops import masks_to_boxes
from torchvision.prototype import datapoints from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._geometry import _check_interpolation
from ._transform import _RandomApplyTransform from ._transform import _RandomApplyTransform
from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size
...@@ -203,11 +204,11 @@ class SimpleCopyPaste(Transform): ...@@ -203,11 +204,11 @@ class SimpleCopyPaste(Transform):
def __init__( def __init__(
self, self,
blending: bool = True, blending: bool = True,
resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR, resize_interpolation: Union[int, InterpolationMode] = F.InterpolationMode.BILINEAR,
antialias: Optional[bool] = None, antialias: Optional[bool] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.resize_interpolation = resize_interpolation self.resize_interpolation = _check_interpolation(resize_interpolation)
self.blending = blending self.blending = blending
self.antialias = antialias self.antialias = antialias
......
...@@ -8,6 +8,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec ...@@ -8,6 +8,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision import transforms as _transforms from torchvision import transforms as _transforms
from torchvision.prototype import datapoints from torchvision.prototype import datapoints
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._geometry import _check_interpolation
from torchvision.prototype.transforms.functional._meta import get_spatial_size from torchvision.prototype.transforms.functional._meta import get_spatial_size
from torchvision.transforms import functional_tensor as _FT from torchvision.transforms import functional_tensor as _FT
...@@ -19,11 +20,11 @@ class _AutoAugmentBase(Transform): ...@@ -19,11 +20,11 @@ class _AutoAugmentBase(Transform):
def __init__( def __init__(
self, self,
*, *,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.interpolation = interpolation self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill) self.fill = _setup_fill_arg(fill)
def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]: def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
...@@ -79,7 +80,7 @@ class _AutoAugmentBase(Transform): ...@@ -79,7 +80,7 @@ class _AutoAugmentBase(Transform):
image: Union[datapoints.ImageType, datapoints.VideoType], image: Union[datapoints.ImageType, datapoints.VideoType],
transform_id: str, transform_id: str,
magnitude: float, magnitude: float,
interpolation: InterpolationMode, interpolation: Union[InterpolationMode, int],
fill: Dict[Type, datapoints.FillTypeJIT], fill: Dict[Type, datapoints.FillTypeJIT],
) -> Union[datapoints.ImageType, datapoints.VideoType]: ) -> Union[datapoints.ImageType, datapoints.VideoType]:
fill_ = fill[type(image)] fill_ = fill[type(image)]
...@@ -193,7 +194,7 @@ class AutoAugment(_AutoAugmentBase): ...@@ -193,7 +194,7 @@ class AutoAugment(_AutoAugmentBase):
def __init__( def __init__(
self, self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
) -> None: ) -> None:
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
...@@ -350,7 +351,7 @@ class RandAugment(_AutoAugmentBase): ...@@ -350,7 +351,7 @@ class RandAugment(_AutoAugmentBase):
num_ops: int = 2, num_ops: int = 2,
magnitude: int = 9, magnitude: int = 9,
num_magnitude_bins: int = 31, num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
) -> None: ) -> None:
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
...@@ -403,7 +404,7 @@ class TrivialAugmentWide(_AutoAugmentBase): ...@@ -403,7 +404,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
def __init__( def __init__(
self, self,
num_magnitude_bins: int = 31, num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
): ):
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
...@@ -461,7 +462,7 @@ class AugMix(_AutoAugmentBase): ...@@ -461,7 +462,7 @@ class AugMix(_AutoAugmentBase):
chain_depth: int = -1, chain_depth: int = -1,
alpha: float = 1.0, alpha: float = 1.0,
all_ops: bool = True, all_ops: bool = True,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None, fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
) -> None: ) -> None:
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
......
...@@ -10,6 +10,7 @@ from torchvision import transforms as _transforms ...@@ -10,6 +10,7 @@ from torchvision import transforms as _transforms
from torchvision.ops.boxes import box_iou from torchvision.ops.boxes import box_iou
from torchvision.prototype import datapoints from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._geometry import _check_interpolation
from torchvision.transforms.functional import _get_perspective_coeffs from torchvision.transforms.functional import _get_perspective_coeffs
from ._transform import _RandomApplyTransform from ._transform import _RandomApplyTransform
...@@ -45,7 +46,7 @@ class Resize(Transform): ...@@ -45,7 +46,7 @@ class Resize(Transform):
def __init__( def __init__(
self, self,
size: Union[int, Sequence[int]], size: Union[int, Sequence[int]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> None: ) -> None:
...@@ -61,7 +62,7 @@ class Resize(Transform): ...@@ -61,7 +62,7 @@ class Resize(Transform):
) )
self.size = size self.size = size
self.interpolation = interpolation self.interpolation = _check_interpolation(interpolation)
self.max_size = max_size self.max_size = max_size
self.antialias = antialias self.antialias = antialias
...@@ -94,7 +95,7 @@ class RandomResizedCrop(Transform): ...@@ -94,7 +95,7 @@ class RandomResizedCrop(Transform):
size: Union[int, Sequence[int]], size: Union[int, Sequence[int]],
scale: Tuple[float, float] = (0.08, 1.0), scale: Tuple[float, float] = (0.08, 1.0),
ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0), ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -111,7 +112,7 @@ class RandomResizedCrop(Transform): ...@@ -111,7 +112,7 @@ class RandomResizedCrop(Transform):
self.scale = scale self.scale = scale
self.ratio = ratio self.ratio = ratio
self.interpolation = interpolation self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias self.antialias = antialias
self._log_ratio = torch.log(torch.tensor(self.ratio)) self._log_ratio = torch.log(torch.tensor(self.ratio))
...@@ -317,14 +318,14 @@ class RandomRotation(Transform): ...@@ -317,14 +318,14 @@ class RandomRotation(Transform):
def __init__( def __init__(
self, self,
degrees: Union[numbers.Number, Sequence], degrees: Union[numbers.Number, Sequence],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
self.interpolation = interpolation self.interpolation = _check_interpolation(interpolation)
self.expand = expand self.expand = expand
self.fill = _setup_fill_arg(fill) self.fill = _setup_fill_arg(fill)
...@@ -359,7 +360,7 @@ class RandomAffine(Transform): ...@@ -359,7 +360,7 @@ class RandomAffine(Transform):
translate: Optional[Sequence[float]] = None, translate: Optional[Sequence[float]] = None,
scale: Optional[Sequence[float]] = None, scale: Optional[Sequence[float]] = None,
shear: Optional[Union[int, float, Sequence[float]]] = None, shear: Optional[Union[int, float, Sequence[float]]] = None,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> None: ) -> None:
...@@ -383,7 +384,7 @@ class RandomAffine(Transform): ...@@ -383,7 +384,7 @@ class RandomAffine(Transform):
else: else:
self.shear = shear self.shear = shear
self.interpolation = interpolation self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill) self.fill = _setup_fill_arg(fill)
if center is not None: if center is not None:
...@@ -546,7 +547,7 @@ class RandomPerspective(_RandomApplyTransform): ...@@ -546,7 +547,7 @@ class RandomPerspective(_RandomApplyTransform):
self, self,
distortion_scale: float = 0.5, distortion_scale: float = 0.5,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
p: float = 0.5, p: float = 0.5,
) -> None: ) -> None:
super().__init__(p=p) super().__init__(p=p)
...@@ -555,7 +556,7 @@ class RandomPerspective(_RandomApplyTransform): ...@@ -555,7 +556,7 @@ class RandomPerspective(_RandomApplyTransform):
raise ValueError("Argument distortion_scale value should be between 0 and 1") raise ValueError("Argument distortion_scale value should be between 0 and 1")
self.distortion_scale = distortion_scale self.distortion_scale = distortion_scale
self.interpolation = interpolation self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill) self.fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
...@@ -608,13 +609,13 @@ class ElasticTransform(Transform): ...@@ -608,13 +609,13 @@ class ElasticTransform(Transform):
alpha: Union[float, Sequence[float]] = 50.0, alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0, sigma: Union[float, Sequence[float]] = 5.0,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0, fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> None: ) -> None:
super().__init__() super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2) self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
self.sigma = _setup_float_or_seq(sigma, "sigma", 2) self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
self.interpolation = interpolation self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill) self.fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
...@@ -760,13 +761,13 @@ class ScaleJitter(Transform): ...@@ -760,13 +761,13 @@ class ScaleJitter(Transform):
self, self,
target_size: Tuple[int, int], target_size: Tuple[int, int],
scale_range: Tuple[float, float] = (0.1, 2.0), scale_range: Tuple[float, float] = (0.1, 2.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
): ):
super().__init__() super().__init__()
self.target_size = target_size self.target_size = target_size
self.scale_range = scale_range self.scale_range = scale_range
self.interpolation = interpolation self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
...@@ -788,13 +789,13 @@ class RandomShortestSize(Transform): ...@@ -788,13 +789,13 @@ class RandomShortestSize(Transform):
self, self,
min_size: Union[List[int], Tuple[int], int], min_size: Union[List[int], Tuple[int], int],
max_size: Optional[int] = None, max_size: Optional[int] = None,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
): ):
super().__init__() super().__init__()
self.min_size = [min_size] if isinstance(min_size, int) else list(min_size) self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
self.max_size = max_size self.max_size = max_size
self.interpolation = interpolation self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
...@@ -935,13 +936,13 @@ class RandomResize(Transform): ...@@ -935,13 +936,13 @@ class RandomResize(Transform):
self, self,
min_size: int, min_size: int,
max_size: int, max_size: int,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> None: ) -> None:
super().__init__() super().__init__()
self.min_size = min_size self.min_size = min_size
self.max_size = max_size self.max_size = max_size
self.interpolation = interpolation self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
......
...@@ -9,6 +9,8 @@ import PIL.Image ...@@ -9,6 +9,8 @@ import PIL.Image
import torch import torch
from torch import Tensor from torch import Tensor
from torchvision.prototype.transforms.functional._geometry import _check_interpolation
from . import functional as F, InterpolationMode from . import functional as F, InterpolationMode
__all__ = ["StereoMatching"] __all__ = ["StereoMatching"]
...@@ -22,7 +24,7 @@ class StereoMatching(torch.nn.Module): ...@@ -22,7 +24,7 @@ class StereoMatching(torch.nn.Module):
resize_size: Optional[Tuple[int, ...]], resize_size: Optional[Tuple[int, ...]],
mean: Tuple[float, ...] = (0.5, 0.5, 0.5), mean: Tuple[float, ...] = (0.5, 0.5, 0.5),
std: Tuple[float, ...] = (0.5, 0.5, 0.5), std: Tuple[float, ...] = (0.5, 0.5, 0.5),
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -36,7 +38,7 @@ class StereoMatching(torch.nn.Module): ...@@ -36,7 +38,7 @@ class StereoMatching(torch.nn.Module):
self.mean = list(mean) self.mean = list(mean)
self.std = list(std) self.std = list(std)
self.interpolation = interpolation self.interpolation = _check_interpolation(interpolation)
self.use_gray_scale = use_gray_scale self.use_gray_scale = use_gray_scale
def forward(self, left_image: Tensor, right_image: Tensor) -> Tuple[Tensor, Tensor]: def forward(self, left_image: Tensor, right_image: Tensor) -> Tuple[Tensor, Tensor]:
......
...@@ -13,6 +13,7 @@ from torchvision.transforms.functional import ( ...@@ -13,6 +13,7 @@ from torchvision.transforms.functional import (
_check_antialias, _check_antialias,
_compute_resized_output_size as __compute_resized_output_size, _compute_resized_output_size as __compute_resized_output_size,
_get_perspective_coeffs, _get_perspective_coeffs,
_interpolation_modes_from_int,
InterpolationMode, InterpolationMode,
pil_modes_mapping, pil_modes_mapping,
pil_to_tensor, pil_to_tensor,
...@@ -27,6 +28,17 @@ from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_ ...@@ -27,6 +28,17 @@ from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_
from ._utils import is_simple_tensor from ._utils import is_simple_tensor
def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
    """Normalize ``interpolation`` to an ``InterpolationMode`` member.

    Args:
        interpolation: Either an ``InterpolationMode`` or an integer, which is
            translated through ``_interpolation_modes_from_int``.

    Returns:
        The matching ``InterpolationMode``; an ``InterpolationMode`` input is
        returned unchanged.

    Raises:
        ValueError: If ``interpolation`` is neither an ``int`` nor an
            ``InterpolationMode``.
    """
    # Integer path first: delegate the lookup to the shared helper.
    # NOTE(review): behavior for integers the helper does not know is decided
    # by the helper itself and is not handled here.
    if isinstance(interpolation, int):
        return _interpolation_modes_from_int(interpolation)

    # Anything that is not already an enum member at this point is rejected.
    if not isinstance(interpolation, InterpolationMode):
        raise ValueError(
            "Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
            f"but got {interpolation}."
        )

    return interpolation
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.flip(-1) return image.flip(-1)
...@@ -142,10 +154,11 @@ def _compute_resized_output_size( ...@@ -142,10 +154,11 @@ def _compute_resized_output_size(
def resize_image_tensor( def resize_image_tensor(
image: torch.Tensor, image: torch.Tensor,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor: ) -> torch.Tensor:
interpolation = _check_interpolation(interpolation)
antialias = _check_antialias(img=image, antialias=antialias, interpolation=interpolation) antialias = _check_antialias(img=image, antialias=antialias, interpolation=interpolation)
assert not isinstance(antialias, str) assert not isinstance(antialias, str)
antialias = False if antialias is None else antialias antialias = False if antialias is None else antialias
...@@ -189,9 +202,10 @@ def resize_image_tensor( ...@@ -189,9 +202,10 @@ def resize_image_tensor(
def resize_image_pil( def resize_image_pil(
image: PIL.Image.Image, image: PIL.Image.Image,
size: Union[Sequence[int], int], size: Union[Sequence[int], int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
) -> PIL.Image.Image: ) -> PIL.Image.Image:
interpolation = _check_interpolation(interpolation)
size = _compute_resized_output_size(image.size[::-1], size=size, max_size=max_size) # type: ignore[arg-type] size = _compute_resized_output_size(image.size[::-1], size=size, max_size=max_size) # type: ignore[arg-type]
return _FP.resize(image, size, interpolation=pil_modes_mapping[interpolation]) return _FP.resize(image, size, interpolation=pil_modes_mapping[interpolation])
...@@ -228,7 +242,7 @@ def resize_bounding_box( ...@@ -228,7 +242,7 @@ def resize_bounding_box(
def resize_video( def resize_video(
video: torch.Tensor, video: torch.Tensor,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -238,7 +252,7 @@ def resize_video( ...@@ -238,7 +252,7 @@ def resize_video(
def resize( def resize(
inpt: datapoints.InputTypeJIT, inpt: datapoints.InputTypeJIT,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
...@@ -513,10 +527,12 @@ def affine_image_tensor( ...@@ -513,10 +527,12 @@ def affine_image_tensor(
translate: List[float], translate: List[float],
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
interpolation = _check_interpolation(interpolation)
if image.numel() == 0: if image.numel() == 0:
return image return image
...@@ -563,10 +579,11 @@ def affine_image_pil( ...@@ -563,10 +579,11 @@ def affine_image_pil(
translate: List[float], translate: List[float],
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> PIL.Image.Image: ) -> PIL.Image.Image:
interpolation = _check_interpolation(interpolation)
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center) angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5) # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
...@@ -731,7 +748,7 @@ def affine_video( ...@@ -731,7 +748,7 @@ def affine_video(
translate: List[float], translate: List[float],
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -753,7 +770,7 @@ def affine( ...@@ -753,7 +770,7 @@ def affine(
translate: List[float], translate: List[float],
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
...@@ -797,11 +814,13 @@ def affine( ...@@ -797,11 +814,13 @@ def affine(
def rotate_image_tensor( def rotate_image_tensor(
image: torch.Tensor, image: torch.Tensor,
angle: float, angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
) -> torch.Tensor: ) -> torch.Tensor:
interpolation = _check_interpolation(interpolation)
shape = image.shape shape = image.shape
num_channels, height, width = shape[-3:] num_channels, height, width = shape[-3:]
...@@ -840,11 +859,13 @@ def rotate_image_tensor( ...@@ -840,11 +859,13 @@ def rotate_image_tensor(
def rotate_image_pil( def rotate_image_pil(
image: PIL.Image.Image, image: PIL.Image.Image,
angle: float, angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
) -> PIL.Image.Image: ) -> PIL.Image.Image:
interpolation = _check_interpolation(interpolation)
if center is not None and expand: if center is not None and expand:
warnings.warn("The provided center argument has no effect on the result if expand is True") warnings.warn("The provided center argument has no effect on the result if expand is True")
center = None center = None
...@@ -910,7 +931,7 @@ def rotate_mask( ...@@ -910,7 +931,7 @@ def rotate_mask(
def rotate_video( def rotate_video(
video: torch.Tensor, video: torch.Tensor,
angle: float, angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
...@@ -921,7 +942,7 @@ def rotate_video( ...@@ -921,7 +942,7 @@ def rotate_video(
def rotate( def rotate(
inpt: datapoints.InputTypeJIT, inpt: datapoints.InputTypeJIT,
angle: float, angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
...@@ -1281,11 +1302,13 @@ def perspective_image_tensor( ...@@ -1281,11 +1302,13 @@ def perspective_image_tensor(
image: torch.Tensor, image: torch.Tensor,
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
interpolation = _check_interpolation(interpolation)
if image.numel() == 0: if image.numel() == 0:
return image return image
...@@ -1326,11 +1349,12 @@ def perspective_image_pil( ...@@ -1326,11 +1349,12 @@ def perspective_image_pil(
image: PIL.Image.Image, image: PIL.Image.Image,
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BICUBIC, interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> PIL.Image.Image: ) -> PIL.Image.Image:
perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
interpolation = _check_interpolation(interpolation)
return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill) return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
...@@ -1455,7 +1479,7 @@ def perspective_video( ...@@ -1455,7 +1479,7 @@ def perspective_video(
video: torch.Tensor, video: torch.Tensor,
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -1468,7 +1492,7 @@ def perspective( ...@@ -1468,7 +1492,7 @@ def perspective(
inpt: datapoints.InputTypeJIT, inpt: datapoints.InputTypeJIT,
startpoints: Optional[List[List[int]]], startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
coefficients: Optional[List[float]] = None, coefficients: Optional[List[float]] = None,
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
...@@ -1496,9 +1520,11 @@ def perspective( ...@@ -1496,9 +1520,11 @@ def perspective(
def elastic_image_tensor( def elastic_image_tensor(
image: torch.Tensor, image: torch.Tensor,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
) -> torch.Tensor: ) -> torch.Tensor:
interpolation = _check_interpolation(interpolation)
if image.numel() == 0: if image.numel() == 0:
return image return image
...@@ -1537,7 +1563,7 @@ def elastic_image_tensor( ...@@ -1537,7 +1563,7 @@ def elastic_image_tensor(
def elastic_image_pil( def elastic_image_pil(
image: PIL.Image.Image, image: PIL.Image.Image,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
) -> PIL.Image.Image: ) -> PIL.Image.Image:
t_img = pil_to_tensor(image) t_img = pil_to_tensor(image)
...@@ -1630,7 +1656,7 @@ def elastic_mask( ...@@ -1630,7 +1656,7 @@ def elastic_mask(
def elastic_video( def elastic_video(
video: torch.Tensor, video: torch.Tensor,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
) -> torch.Tensor: ) -> torch.Tensor:
return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill) return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
...@@ -1639,7 +1665,7 @@ def elastic_video( ...@@ -1639,7 +1665,7 @@ def elastic_video(
def elastic( def elastic(
inpt: datapoints.InputTypeJIT, inpt: datapoints.InputTypeJIT,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None, fill: datapoints.FillTypeJIT = None,
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
...@@ -1778,7 +1804,7 @@ def resized_crop_image_tensor( ...@@ -1778,7 +1804,7 @@ def resized_crop_image_tensor(
height: int, height: int,
width: int, width: int,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor: ) -> torch.Tensor:
image = crop_image_tensor(image, top, left, height, width) image = crop_image_tensor(image, top, left, height, width)
...@@ -1793,7 +1819,7 @@ def resized_crop_image_pil( ...@@ -1793,7 +1819,7 @@ def resized_crop_image_pil(
height: int, height: int,
width: int, width: int,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> PIL.Image.Image: ) -> PIL.Image.Image:
image = crop_image_pil(image, top, left, height, width) image = crop_image_pil(image, top, left, height, width)
return resize_image_pil(image, size, interpolation=interpolation) return resize_image_pil(image, size, interpolation=interpolation)
...@@ -1831,7 +1857,7 @@ def resized_crop_video( ...@@ -1831,7 +1857,7 @@ def resized_crop_video(
height: int, height: int,
width: int, width: int,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor: ) -> torch.Tensor:
return resized_crop_image_tensor( return resized_crop_image_tensor(
...@@ -1846,7 +1872,7 @@ def resized_crop( ...@@ -1846,7 +1872,7 @@ def resized_crop(
height: int, height: int,
width: int, width: int,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn", antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints.InputTypeJIT: ) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment