Unverified commit 0e0a5dc7, authored by Philip Meier and committed by GitHub
Browse files

Support integer values for interpolation in the prototype transforms (#7248)

parent f627b9d1
......@@ -1534,7 +1534,7 @@ class TestScaleJitter:
assert int(spatial_size[1] * r_min) <= width <= int(spatial_size[1] * r_max)
def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock()
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()
transform = transforms.ScaleJitter(
......@@ -1581,7 +1581,7 @@ class TestRandomShortestSize:
assert shorter in min_size
def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock()
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()
transform = transforms.RandomShortestSize(
......@@ -1945,7 +1945,7 @@ class TestRandomResize:
assert min_size <= size < max_size
def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock()
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()
transform = transforms.RandomResize(
......
......@@ -88,6 +88,9 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC),
NotScriptableArgsKwargs(31, max_size=32),
ArgsKwargs([31], max_size=32),
NotScriptableArgsKwargs(30, max_size=100),
......@@ -305,6 +308,8 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(25, ratio=(0.5, 1.5)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC),
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
],
......@@ -352,6 +357,8 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(sigma=(2.5, 3.9)),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs(interpolation=PIL.Image.NEAREST),
ArgsKwargs(interpolation=PIL.Image.BICUBIC),
ArgsKwargs(fill=1),
],
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
......@@ -386,6 +393,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST),
ArgsKwargs(degrees=30.0, fill=1),
ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
ArgsKwargs(degrees=30.0, center=(0, 0)),
......@@ -420,6 +428,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(p=1),
ArgsKwargs(p=1, distortion_scale=0.3),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
],
......@@ -432,6 +441,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(degrees=30.0),
ArgsKwargs(degrees=(-20.0, 10.0)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.BILINEAR),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR),
ArgsKwargs(degrees=30.0, expand=True),
ArgsKwargs(degrees=30.0, center=(0, 0)),
ArgsKwargs(degrees=30.0, fill=1),
......@@ -851,7 +861,11 @@ class TestAATransforms:
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_randaug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
......@@ -889,7 +903,11 @@ class TestAATransforms:
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_trivial_aug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
......@@ -937,7 +955,11 @@ class TestAATransforms:
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_augmix(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
......@@ -986,7 +1008,11 @@ class TestAATransforms:
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_aa(self, inpt, interpolation):
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
......@@ -1264,13 +1290,13 @@ class TestRefSegTransforms:
(legacy_F.convert_image_dtype, {}),
(legacy_F.to_pil_image, {}),
(legacy_F.normalize, {}),
(legacy_F.resize, {}),
(legacy_F.resize, {"interpolation"}),
(legacy_F.pad, {"padding", "fill"}),
(legacy_F.crop, {}),
(legacy_F.center_crop, {}),
(legacy_F.resized_crop, {}),
(legacy_F.resized_crop, {"interpolation"}),
(legacy_F.hflip, {}),
(legacy_F.perspective, {"startpoints", "endpoints", "fill"}),
(legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
(legacy_F.vflip, {}),
(legacy_F.five_crop, {}),
(legacy_F.ten_crop, {}),
......@@ -1279,8 +1305,8 @@ class TestRefSegTransforms:
(legacy_F.adjust_saturation, {}),
(legacy_F.adjust_hue, {}),
(legacy_F.adjust_gamma, {}),
(legacy_F.rotate, {"center", "fill"}),
(legacy_F.affine, {"angle", "translate", "center", "fill"}),
(legacy_F.rotate, {"center", "fill", "interpolation"}),
(legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
(legacy_F.to_grayscale, {}),
(legacy_F.rgb_to_grayscale, {}),
(legacy_F.to_tensor, {}),
......@@ -1292,7 +1318,7 @@ class TestRefSegTransforms:
(legacy_F.adjust_sharpness, {}),
(legacy_F.autocontrast, {}),
(legacy_F.equalize, {}),
(legacy_F.elastic_transform, {"fill"}),
(legacy_F.elastic_transform, {"fill", "interpolation"}),
],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
......
......@@ -76,7 +76,7 @@ class BoundingBox(Datapoint):
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBox:
......@@ -107,7 +107,7 @@ class BoundingBox(Datapoint):
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBox:
output, spatial_size = self._F.resized_crop_bounding_box(
......@@ -133,7 +133,7 @@ class BoundingBox(Datapoint):
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
......@@ -154,7 +154,7 @@ class BoundingBox(Datapoint):
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
......@@ -174,7 +174,7 @@ class BoundingBox(Datapoint):
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> BoundingBox:
......@@ -191,7 +191,7 @@ class BoundingBox(Datapoint):
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> BoundingBox:
output = self._F.elastic_bounding_box(
......
......@@ -143,7 +143,7 @@ class Datapoint(torch.Tensor):
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Datapoint:
......@@ -162,7 +162,7 @@ class Datapoint(torch.Tensor):
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> Datapoint:
return self
......@@ -178,7 +178,7 @@ class Datapoint(torch.Tensor):
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
......@@ -191,7 +191,7 @@ class Datapoint(torch.Tensor):
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Datapoint:
......@@ -201,7 +201,7 @@ class Datapoint(torch.Tensor):
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Datapoint:
......@@ -210,7 +210,7 @@ class Datapoint(torch.Tensor):
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Datapoint:
return self
......
......@@ -62,7 +62,7 @@ class Image(Datapoint):
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Image:
......@@ -86,7 +86,7 @@ class Image(Datapoint):
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> Image:
output = self._F.resized_crop_image_tensor(
......@@ -113,7 +113,7 @@ class Image(Datapoint):
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
......@@ -129,7 +129,7 @@ class Image(Datapoint):
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
......@@ -149,7 +149,7 @@ class Image(Datapoint):
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Image:
......@@ -166,7 +166,7 @@ class Image(Datapoint):
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Image:
output = self._F.elastic_image_tensor(
......
......@@ -53,7 +53,7 @@ class Mask(Datapoint):
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Mask:
......@@ -75,7 +75,7 @@ class Mask(Datapoint):
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
antialias: Optional[Union[str, bool]] = "warn",
) -> Mask:
output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
......@@ -93,7 +93,7 @@ class Mask(Datapoint):
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
......@@ -107,7 +107,7 @@ class Mask(Datapoint):
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Mask:
......@@ -126,7 +126,7 @@ class Mask(Datapoint):
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Mask:
......@@ -138,7 +138,7 @@ class Mask(Datapoint):
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
) -> Mask:
output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
......
......@@ -57,7 +57,7 @@ class Video(Datapoint):
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Video:
......@@ -85,7 +85,7 @@ class Video(Datapoint):
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> Video:
output = self._F.resized_crop_video(
......@@ -112,7 +112,7 @@ class Video(Datapoint):
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
......@@ -128,7 +128,7 @@ class Video(Datapoint):
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Video:
......@@ -148,7 +148,7 @@ class Video(Datapoint):
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Video:
......@@ -165,7 +165,7 @@ class Video(Datapoint):
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Video:
output = self._F.elastic_video(
......
......@@ -10,6 +10,7 @@ from torchvision import transforms as _transforms
from torchvision.ops import masks_to_boxes
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._geometry import _check_interpolation
from ._transform import _RandomApplyTransform
from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size
......@@ -203,11 +204,11 @@ class SimpleCopyPaste(Transform):
def __init__(
self,
blending: bool = True,
resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR,
resize_interpolation: Union[int, InterpolationMode] = F.InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> None:
super().__init__()
self.resize_interpolation = resize_interpolation
self.resize_interpolation = _check_interpolation(resize_interpolation)
self.blending = blending
self.antialias = antialias
......
......@@ -8,6 +8,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision import transforms as _transforms
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._geometry import _check_interpolation
from torchvision.prototype.transforms.functional._meta import get_spatial_size
from torchvision.transforms import functional_tensor as _FT
......@@ -19,11 +20,11 @@ class _AutoAugmentBase(Transform):
def __init__(
self,
*,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
) -> None:
super().__init__()
self.interpolation = interpolation
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
def _get_random_item(self, dct: Dict[str, Tuple[Callable, bool]]) -> Tuple[str, Tuple[Callable, bool]]:
......@@ -79,7 +80,7 @@ class _AutoAugmentBase(Transform):
image: Union[datapoints.ImageType, datapoints.VideoType],
transform_id: str,
magnitude: float,
interpolation: InterpolationMode,
interpolation: Union[InterpolationMode, int],
fill: Dict[Type, datapoints.FillTypeJIT],
) -> Union[datapoints.ImageType, datapoints.VideoType]:
fill_ = fill[type(image)]
......@@ -193,7 +194,7 @@ class AutoAugment(_AutoAugmentBase):
def __init__(
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
......@@ -350,7 +351,7 @@ class RandAugment(_AutoAugmentBase):
num_ops: int = 2,
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
......@@ -403,7 +404,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
def __init__(
self,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
):
super().__init__(interpolation=interpolation, fill=fill)
......@@ -461,7 +462,7 @@ class AugMix(_AutoAugmentBase):
chain_depth: int = -1,
alpha: float = 1.0,
all_ops: bool = True,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
......
......@@ -10,6 +10,7 @@ from torchvision import transforms as _transforms
from torchvision.ops.boxes import box_iou
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._geometry import _check_interpolation
from torchvision.transforms.functional import _get_perspective_coeffs
from ._transform import _RandomApplyTransform
......@@ -45,7 +46,7 @@ class Resize(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> None:
......@@ -61,7 +62,7 @@ class Resize(Transform):
)
self.size = size
self.interpolation = interpolation
self.interpolation = _check_interpolation(interpolation)
self.max_size = max_size
self.antialias = antialias
......@@ -94,7 +95,7 @@ class RandomResizedCrop(Transform):
size: Union[int, Sequence[int]],
scale: Tuple[float, float] = (0.08, 1.0),
ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> None:
super().__init__()
......@@ -111,7 +112,7 @@ class RandomResizedCrop(Transform):
self.scale = scale
self.ratio = ratio
self.interpolation = interpolation
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
self._log_ratio = torch.log(torch.tensor(self.ratio))
......@@ -317,14 +318,14 @@ class RandomRotation(Transform):
def __init__(
self,
degrees: Union[numbers.Number, Sequence],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
self.interpolation = interpolation
self.interpolation = _check_interpolation(interpolation)
self.expand = expand
self.fill = _setup_fill_arg(fill)
......@@ -359,7 +360,7 @@ class RandomAffine(Transform):
translate: Optional[Sequence[float]] = None,
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[int, float, Sequence[float]]] = None,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
......@@ -383,7 +384,7 @@ class RandomAffine(Transform):
else:
self.shear = shear
self.interpolation = interpolation
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
if center is not None:
......@@ -546,7 +547,7 @@ class RandomPerspective(_RandomApplyTransform):
self,
distortion_scale: float = 0.5,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
p: float = 0.5,
) -> None:
super().__init__(p=p)
......@@ -555,7 +556,7 @@ class RandomPerspective(_RandomApplyTransform):
raise ValueError("Argument distortion_scale value should be between 0 and 1")
self.distortion_scale = distortion_scale
self.interpolation = interpolation
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
......@@ -608,13 +609,13 @@ class ElasticTransform(Transform):
alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0,
fill: Union[datapoints.FillType, Dict[Type, datapoints.FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
self.interpolation = interpolation
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
......@@ -760,13 +761,13 @@ class ScaleJitter(Transform):
self,
target_size: Tuple[int, int],
scale_range: Tuple[float, float] = (0.1, 2.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
):
super().__init__()
self.target_size = target_size
self.scale_range = scale_range
self.interpolation = interpolation
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
......@@ -788,13 +789,13 @@ class RandomShortestSize(Transform):
self,
min_size: Union[List[int], Tuple[int], int],
max_size: Optional[int] = None,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
):
super().__init__()
self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
self.max_size = max_size
self.interpolation = interpolation
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
......@@ -935,13 +936,13 @@ class RandomResize(Transform):
self,
min_size: int,
max_size: int,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> None:
super().__init__()
self.min_size = min_size
self.max_size = max_size
self.interpolation = interpolation
self.interpolation = _check_interpolation(interpolation)
self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
......
......@@ -9,6 +9,8 @@ import PIL.Image
import torch
from torch import Tensor
from torchvision.prototype.transforms.functional._geometry import _check_interpolation
from . import functional as F, InterpolationMode
__all__ = ["StereoMatching"]
......@@ -22,7 +24,7 @@ class StereoMatching(torch.nn.Module):
resize_size: Optional[Tuple[int, ...]],
mean: Tuple[float, ...] = (0.5, 0.5, 0.5),
std: Tuple[float, ...] = (0.5, 0.5, 0.5),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
......@@ -36,7 +38,7 @@ class StereoMatching(torch.nn.Module):
self.mean = list(mean)
self.std = list(std)
self.interpolation = interpolation
self.interpolation = _check_interpolation(interpolation)
self.use_gray_scale = use_gray_scale
def forward(self, left_image: Tensor, right_image: Tensor) -> Tuple[Tensor, Tensor]:
......
......@@ -13,6 +13,7 @@ from torchvision.transforms.functional import (
_check_antialias,
_compute_resized_output_size as __compute_resized_output_size,
_get_perspective_coeffs,
_interpolation_modes_from_int,
InterpolationMode,
pil_modes_mapping,
pil_to_tensor,
......@@ -27,6 +28,17 @@ from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_
from ._utils import is_simple_tensor
def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
    """Normalize an ``interpolation`` argument to an ``InterpolationMode``.

    Args:
        interpolation: Either an ``InterpolationMode`` member or an integer
            (per the error message below, a Pillow integer constant).

    Returns:
        The ``InterpolationMode`` passed in, or the result of
        ``_interpolation_modes_from_int`` for integer input.

    Raises:
        ValueError: If ``interpolation`` is neither an ``int`` nor an
            ``InterpolationMode``.
    """
    # Integers are checked first and delegated to the legacy conversion helper.
    # NOTE(review): how the helper reacts to an unknown integer is not visible here.
    if isinstance(interpolation, int):
        return _interpolation_modes_from_int(interpolation)

    # Already the canonical type: hand it back untouched.
    if isinstance(interpolation, InterpolationMode):
        return interpolation

    raise ValueError(
        f"Argument interpolation should be an `InterpolationMode` or a corresponding Pillow integer constant, "
        f"but got {interpolation}."
    )
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.flip(-1)
......@@ -142,10 +154,11 @@ def _compute_resized_output_size(
def resize_image_tensor(
image: torch.Tensor,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
interpolation = _check_interpolation(interpolation)
antialias = _check_antialias(img=image, antialias=antialias, interpolation=interpolation)
assert not isinstance(antialias, str)
antialias = False if antialias is None else antialias
......@@ -189,9 +202,10 @@ def resize_image_tensor(
def resize_image_pil(
image: PIL.Image.Image,
size: Union[Sequence[int], int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
) -> PIL.Image.Image:
interpolation = _check_interpolation(interpolation)
size = _compute_resized_output_size(image.size[::-1], size=size, max_size=max_size) # type: ignore[arg-type]
return _FP.resize(image, size, interpolation=pil_modes_mapping[interpolation])
......@@ -228,7 +242,7 @@ def resize_bounding_box(
def resize_video(
video: torch.Tensor,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
......@@ -238,7 +252,7 @@ def resize_video(
def resize(
inpt: datapoints.InputTypeJIT,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints.InputTypeJIT:
......@@ -513,10 +527,12 @@ def affine_image_tensor(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: datapoints.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
interpolation = _check_interpolation(interpolation)
if image.numel() == 0:
return image
......@@ -563,10 +579,11 @@ def affine_image_pil(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: datapoints.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> PIL.Image.Image:
interpolation = _check_interpolation(interpolation)
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
......@@ -731,7 +748,7 @@ def affine_video(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: datapoints.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
......@@ -753,7 +770,7 @@ def affine(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: datapoints.FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> datapoints.InputTypeJIT:
......@@ -797,11 +814,13 @@ def affine(
def rotate_image_tensor(
image: torch.Tensor,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: datapoints.FillTypeJIT = None,
) -> torch.Tensor:
interpolation = _check_interpolation(interpolation)
shape = image.shape
num_channels, height, width = shape[-3:]
......@@ -840,11 +859,13 @@ def rotate_image_tensor(
def rotate_image_pil(
image: PIL.Image.Image,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: datapoints.FillTypeJIT = None,
) -> PIL.Image.Image:
interpolation = _check_interpolation(interpolation)
if center is not None and expand:
warnings.warn("The provided center argument has no effect on the result if expand is True")
center = None
......@@ -910,7 +931,7 @@ def rotate_mask(
def rotate_video(
video: torch.Tensor,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: datapoints.FillTypeJIT = None,
......@@ -921,7 +942,7 @@ def rotate_video(
def rotate(
inpt: datapoints.InputTypeJIT,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: datapoints.FillTypeJIT = None,
......@@ -1281,11 +1302,13 @@ def perspective_image_tensor(
image: torch.Tensor,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
interpolation = _check_interpolation(interpolation)
if image.numel() == 0:
return image
......@@ -1326,11 +1349,12 @@ def perspective_image_pil(
image: PIL.Image.Image,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BICUBIC,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BICUBIC,
fill: datapoints.FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> PIL.Image.Image:
perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
interpolation = _check_interpolation(interpolation)
return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
......@@ -1455,7 +1479,7 @@ def perspective_video(
video: torch.Tensor,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> torch.Tensor:
......@@ -1468,7 +1492,7 @@ def perspective(
inpt: datapoints.InputTypeJIT,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> datapoints.InputTypeJIT:
......@@ -1496,9 +1520,11 @@ def perspective(
def elastic_image_tensor(
image: torch.Tensor,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None,
) -> torch.Tensor:
interpolation = _check_interpolation(interpolation)
if image.numel() == 0:
return image
......@@ -1537,7 +1563,7 @@ def elastic_image_tensor(
def elastic_image_pil(
image: PIL.Image.Image,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None,
) -> PIL.Image.Image:
t_img = pil_to_tensor(image)
......@@ -1630,7 +1656,7 @@ def elastic_mask(
def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: datapoints.FillTypeJIT = None,
) -> torch.Tensor:
    """Apply an elastic displacement to a video tensor.

    This is a thin wrapper: every argument is forwarded unchanged to
    ``elastic_image_tensor``, which performs the actual work.

    Args:
        video: Video tensor to transform.
        displacement: Displacement tensor, forwarded as-is.
        interpolation: Interpolation mode, given either as an
            ``InterpolationMode`` member or as an integer. It is not
            validated here; it is passed through to ``elastic_image_tensor``.
        fill: Fill value, forwarded as-is.

    Returns:
        Whatever ``elastic_image_tensor`` returns for the given arguments.
    """
    # NOTE: ``interpolation`` must be declared exactly once; a duplicated
    # parameter name is a SyntaxError. The wider Union annotation is kept so
    # integer interpolation values remain accepted.
    return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
......@@ -1639,7 +1665,7 @@ def elastic_video(
def elastic(
inpt: datapoints.InputTypeJIT,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: datapoints.FillTypeJIT = None,
) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
......@@ -1778,7 +1804,7 @@ def resized_crop_image_tensor(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
image = crop_image_tensor(image, top, left, height, width)
......@@ -1793,7 +1819,7 @@ def resized_crop_image_pil(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
image = crop_image_pil(image, top, left, height, width)
return resize_image_pil(image, size, interpolation=interpolation)
......@@ -1831,7 +1857,7 @@ def resized_crop_video(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> torch.Tensor:
return resized_crop_image_tensor(
......@@ -1846,7 +1872,7 @@ def resized_crop(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment