Unverified Commit 0c445130 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[BC-breaking] Introduced InterpolationModes and deprecated arguments: resample...

[BC-breaking] Introduced InterpolationModes and deprecated arguments: resample and fillcolor (#2952)

* Deprecated arguments: resample and fillcolor
Replaced by interpolation and fill

* Updates according to the review

* Added tests to check warnings and asserted BC

* [WIP] Interpolation modes

* Added InterpolationModes enum

* Added support for int values for interpolation for BC

* Removed useless test code

* Fix flake8
parent 240210c9
...@@ -4,16 +4,19 @@ import colorsys ...@@ -4,16 +4,19 @@ import colorsys
import math import math
import numpy as np import numpy as np
from PIL.Image import NEAREST, BILINEAR, BICUBIC
import torch import torch
import torchvision.transforms.functional_tensor as F_t import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional_pil as F_pil import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationModes
from common_utils import TransformsTester from common_utils import TransformsTester
NEAREST, BILINEAR, BICUBIC = InterpolationModes.NEAREST, InterpolationModes.BILINEAR, InterpolationModes.BICUBIC
class Tester(TransformsTester): class Tester(TransformsTester):
def setUp(self): def setUp(self):
...@@ -365,7 +368,7 @@ class Tester(TransformsTester): ...@@ -365,7 +368,7 @@ class Tester(TransformsTester):
) )
def test_resize(self): def test_resize(self):
script_fn = torch.jit.script(F_t.resize) script_fn = torch.jit.script(F.resize)
tensor, pil_img = self._create_data(26, 36, device=self.device) tensor, pil_img = self._create_data(26, 36, device=self.device)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
...@@ -382,14 +385,14 @@ class Tester(TransformsTester): ...@@ -382,14 +385,14 @@ class Tester(TransformsTester):
for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]: for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
for interpolation in [BILINEAR, BICUBIC, NEAREST]: for interpolation in [BILINEAR, BICUBIC, NEAREST]:
resized_tensor = F_t.resize(tensor, size=size, interpolation=interpolation) resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
resized_pil_img = F_pil.resize(pil_img, size=size, interpolation=interpolation) resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
self.assertEqual( self.assertEqual(
resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation) resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation)
) )
if interpolation != NEAREST: if interpolation not in [NEAREST, ]:
# We can not check values if mode = NEAREST, as results are different # We can not check values if mode = NEAREST, as results are different
# E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]] # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
# E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]] # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
...@@ -407,6 +410,7 @@ class Tester(TransformsTester): ...@@ -407,6 +410,7 @@ class Tester(TransformsTester):
script_size = [size, ] script_size = [size, ]
else: else:
script_size = size script_size = size
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation) resize_result = script_fn(tensor, size=script_size, interpolation=interpolation)
self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation)) self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
...@@ -414,17 +418,24 @@ class Tester(TransformsTester): ...@@ -414,17 +418,24 @@ class Tester(TransformsTester):
batch_tensors, F.resize, size=script_size, interpolation=interpolation batch_tensors, F.resize, size=script_size, interpolation=interpolation
) )
# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
res1 = F.resize(tensor, size=32, interpolation=2)
res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
def test_resized_crop(self): def test_resized_crop(self):
# test values of F.resized_crop in several cases: # test values of F.resized_crop in several cases:
# 1) resize to the same size, crop to the same size => should be identity # 1) resize to the same size, crop to the same size => should be identity
tensor, _ = self._create_data(26, 36, device=self.device) tensor, _ = self._create_data(26, 36, device=self.device)
for i in [0, 2, 3]:
out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=i) for mode in [NEAREST, BILINEAR, BICUBIC]:
out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))
# 2) resize by half and crop a TL corner # 2) resize by half and crop a TL corner
tensor, _ = self._create_data(26, 36, device=self.device) tensor, _ = self._create_data(26, 36, device=self.device)
out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=0) out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST)
expected_out_tensor = tensor[:, :20:2, :30:2] expected_out_tensor = tensor[:, :20:2, :30:2]
self.assertTrue( self.assertTrue(
expected_out_tensor.equal(out_tensor), expected_out_tensor.equal(out_tensor),
...@@ -433,17 +444,19 @@ class Tester(TransformsTester): ...@@ -433,17 +444,19 @@ class Tester(TransformsTester):
batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device) batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
self._test_fn_on_batch( self._test_fn_on_batch(
batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=0 batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST
) )
def _test_affine_identity_map(self, tensor, scripted_affine): def _test_affine_identity_map(self, tensor, scripted_affine):
# 1) identity map # 1) identity map
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
self.assertTrue( self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]) tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
) )
out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) out_tensor = scripted_affine(
tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
)
self.assertTrue( self.assertTrue(
tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]) tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
) )
...@@ -461,13 +474,13 @@ class Tester(TransformsTester): ...@@ -461,13 +474,13 @@ class Tester(TransformsTester):
] ]
for a, true_tensor in test_configs: for a, true_tensor in test_configs:
out_pil_img = F.affine( out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
) )
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(self.device) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(self.device)
for fn in [F.affine, scripted_affine]: for fn in [F.affine, scripted_affine]:
out_tensor = fn( out_tensor = fn(
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
) )
if true_tensor is not None: if true_tensor is not None:
self.assertTrue( self.assertTrue(
...@@ -496,13 +509,13 @@ class Tester(TransformsTester): ...@@ -496,13 +509,13 @@ class Tester(TransformsTester):
for a in test_configs: for a in test_configs:
out_pil_img = F.affine( out_pil_img = F.affine(
pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
) )
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.affine, scripted_affine]: for fn in [F.affine, scripted_affine]:
out_tensor = fn( out_tensor = fn(
tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
).cpu() ).cpu()
if out_tensor.dtype != torch.uint8: if out_tensor.dtype != torch.uint8:
...@@ -526,10 +539,10 @@ class Tester(TransformsTester): ...@@ -526,10 +539,10 @@ class Tester(TransformsTester):
] ]
for t in test_configs: for t in test_configs:
out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0) out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
for fn in [F.affine, scripted_affine]: for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0) out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
if out_tensor.dtype != torch.uint8: if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8) out_tensor = out_tensor.to(torch.uint8)
...@@ -550,13 +563,13 @@ class Tester(TransformsTester): ...@@ -550,13 +563,13 @@ class Tester(TransformsTester):
(-45, [-10, -10], 1.2, [4.0, 5.0]), (-45, [-10, -10], 1.2, [4.0, 5.0]),
(-90, [0, 0], 1.0, [0.0, 0.0]), (-90, [0, 0], 1.0, [0.0, 0.0]),
] ]
for r in [0, ]: for r in [NEAREST, ]:
for a, t, s, sh in test_configs: for a, t, s, sh in test_configs:
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r) out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.affine, scripted_affine]: for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r).cpu() out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r).cpu()
if out_tensor.dtype != torch.uint8: if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8) out_tensor = out_tensor.to(torch.uint8)
...@@ -605,18 +618,36 @@ class Tester(TransformsTester): ...@@ -605,18 +618,36 @@ class Tester(TransformsTester):
batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0] batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]
) )
tensor, pil_img = data[0]
# assert deprecation warning and non-BC
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=2)
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2)
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"):
res1 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fillcolor=10)
res2 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fill=10)
self.assertEqual(res1, res2)
def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers): def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
img_size = pil_img.size img_size = pil_img.size
dt = tensor.dtype dt = tensor.dtype
for r in [0, ]: for r in [NEAREST, ]:
for a in range(-180, 180, 17): for a in range(-180, 180, 17):
for e in [True, False]: for e in [True, False]:
for c in centers: for c in centers:
out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c) out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.rotate, scripted_rotate]: for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c).cpu() out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c).cpu()
if out_tensor.dtype != torch.uint8: if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8) out_tensor = out_tensor.to(torch.uint8)
...@@ -673,12 +704,24 @@ class Tester(TransformsTester): ...@@ -673,12 +704,24 @@ class Tester(TransformsTester):
center = (20, 22) center = (20, 22)
self._test_fn_on_batch( self._test_fn_on_batch(
batch_tensors, F.rotate, angle=32, resample=0, expand=True, center=center batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
) )
tensor, pil_img = data[0]
# assert deprecation warning and non-BC
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
res1 = F.rotate(tensor, 45, resample=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
res1 = F.rotate(tensor, 45, interpolation=2)
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs): def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
dt = tensor.dtype dt = tensor.dtype
for r in [0, ]: for r in [NEAREST, ]:
for spoints, epoints in test_configs: for spoints, epoints in test_configs:
out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r) out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
...@@ -739,9 +782,17 @@ class Tester(TransformsTester): ...@@ -739,9 +782,17 @@ class Tester(TransformsTester):
for spoints, epoints in test_configs: for spoints, epoints in test_configs:
self._test_fn_on_batch( self._test_fn_on_batch(
batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0 batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=NEAREST
) )
# assert changed type warning
spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2)
res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
def test_gaussian_blur(self): def test_gaussian_blur(self):
small_image_tensor = torch.from_numpy( small_image_tensor = torch.from_numpy(
np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3)) np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
......
...@@ -1492,11 +1492,21 @@ class Tester(unittest.TestCase): ...@@ -1492,11 +1492,21 @@ class Tester(unittest.TestCase):
t = transforms.RandomRotation((-10, 10)) t = transforms.RandomRotation((-10, 10))
angle = t.get_params(t.degrees) angle = t.get_params(t.degrees)
self.assertTrue(angle > -10 and angle < 10) self.assertTrue(-10 < angle < 10)
# Checking if RandomRotation can be printed as string # Checking if RandomRotation can be printed as string
t.__repr__() t.__repr__()
# assert deprecation warning and non-BC
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
t = transforms.RandomRotation((-10, 10), resample=2)
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
t = transforms.RandomRotation((-10, 10), interpolation=2)
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
def test_random_affine(self): def test_random_affine(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
...@@ -1537,8 +1547,22 @@ class Tester(unittest.TestCase): ...@@ -1537,8 +1547,22 @@ class Tester(unittest.TestCase):
# Checking if RandomAffine can be printed as string # Checking if RandomAffine can be printed as string
t.__repr__() t.__repr__()
t = transforms.RandomAffine(10, resample=Image.BILINEAR) t = transforms.RandomAffine(10, interpolation=transforms.InterpolationModes.BILINEAR)
self.assertIn("Image.BILINEAR", t.__repr__()) self.assertIn("bilinear", t.__repr__())
# assert deprecation warning and non-BC
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
t = transforms.RandomAffine(10, resample=2)
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"):
t = transforms.RandomAffine(10, fillcolor=10)
self.assertEqual(t.fill, 10)
# assert changed type warning
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"):
t = transforms.RandomAffine(10, interpolation=2)
self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR)
def test_to_grayscale(self): def test_to_grayscale(self):
"""Unit tests for grayscale transform""" """Unit tests for grayscale transform"""
......
...@@ -2,8 +2,7 @@ import os ...@@ -2,8 +2,7 @@ import os
import torch import torch
from torchvision import transforms as T from torchvision import transforms as T
from torchvision.transforms import functional as F from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationModes
from PIL.Image import NEAREST, BILINEAR, BICUBIC
import numpy as np import numpy as np
...@@ -12,6 +11,9 @@ import unittest ...@@ -12,6 +11,9 @@ import unittest
from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
NEAREST, BILINEAR, BICUBIC = InterpolationModes.NEAREST, InterpolationModes.BILINEAR, InterpolationModes.BICUBIC
class Tester(TransformsTester): class Tester(TransformsTester):
def setUp(self): def setUp(self):
...@@ -349,7 +351,7 @@ class Tester(TransformsTester): ...@@ -349,7 +351,7 @@ class Tester(TransformsTester):
for interpolation in [NEAREST, BILINEAR]: for interpolation in [NEAREST, BILINEAR]:
transform = T.RandomAffine( transform = T.RandomAffine(
degrees=degrees, translate=translate, degrees=degrees, translate=translate,
scale=scale, shear=shear, resample=interpolation scale=scale, shear=shear, interpolation=interpolation
) )
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
...@@ -368,7 +370,7 @@ class Tester(TransformsTester): ...@@ -368,7 +370,7 @@ class Tester(TransformsTester):
for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]: for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
for interpolation in [NEAREST, BILINEAR]: for interpolation in [NEAREST, BILINEAR]:
transform = T.RandomRotation( transform = T.RandomRotation(
degrees=degrees, resample=interpolation, expand=expand, center=center degrees=degrees, interpolation=interpolation, expand=expand, center=center
) )
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
......
import math import math
import numbers import numbers
import warnings import warnings
from enum import Enum
from typing import Any, Optional from typing import Any, Optional
import numpy as np import numpy as np
...@@ -19,6 +20,41 @@ from . import functional_pil as F_pil ...@@ -19,6 +20,41 @@ from . import functional_pil as F_pil
from . import functional_tensor as F_t from . import functional_tensor as F_t
class InterpolationModes(Enum):
"""Interpolation modes
"""
NEAREST = "nearest"
BILINEAR = "bilinear"
BICUBIC = "bicubic"
# For PIL compatibility
BOX = "box"
HAMMING = "hamming"
LANCZOS = "lanczos"
# TODO: Once torchscript supports Enums with staticmethod
# this can be put into InterpolationModes as staticmethod
def _interpolation_modes_from_int(i: int) -> InterpolationModes:
inverse_modes_mapping = {
0: InterpolationModes.NEAREST,
2: InterpolationModes.BILINEAR,
3: InterpolationModes.BICUBIC,
4: InterpolationModes.BOX,
5: InterpolationModes.HAMMING,
1: InterpolationModes.LANCZOS,
}
return inverse_modes_mapping[i]
pil_modes_mapping = {
InterpolationModes.NEAREST: 0,
InterpolationModes.BILINEAR: 2,
InterpolationModes.BICUBIC: 3,
InterpolationModes.BOX: 4,
InterpolationModes.HAMMING: 5,
InterpolationModes.LANCZOS: 1,
}
_is_pil_image = F_pil._is_pil_image _is_pil_image = F_pil._is_pil_image
_parse_fill = F_pil._parse_fill _parse_fill = F_pil._parse_fill
...@@ -293,7 +329,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool ...@@ -293,7 +329,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
return tensor return tensor
def resize(img: Tensor, size: List[int], interpolation: int = Image.BILINEAR) -> Tensor: def resize(img: Tensor, size: List[int], interpolation: InterpolationModes = InterpolationModes.BILINEAR) -> Tensor:
r"""Resize the input image to the given size. r"""Resize the input image to the given size.
The image can be a PIL Image or a torch Tensor, in which case it is expected The image can be a PIL Image or a torch Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
...@@ -307,17 +343,31 @@ def resize(img: Tensor, size: List[int], interpolation: int = Image.BILINEAR) -> ...@@ -307,17 +343,31 @@ def resize(img: Tensor, size: List[int], interpolation: int = Image.BILINEAR) ->
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`. :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
In torchscript mode size as single int is not supported, use a tuple or In torchscript mode size as single int is not supported, use a tuple or
list of length 1: ``[size, ]``. list of length 1: ``[size, ]``.
interpolation (int, optional): Desired interpolation enum defined by `filters`_. interpolation (InterpolationModes): Desired interpolation enum defined by
Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` :class:`torchvision.transforms.InterpolationModes`.
and ``PIL.Image.BICUBIC`` are supported. Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``,
``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
Returns: Returns:
PIL Image or Tensor: Resized image. PIL Image or Tensor: Resized image.
""" """
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationModes instead of int. "
"Please, use InterpolationModes enum."
)
interpolation = _interpolation_modes_from_int(interpolation)
if not isinstance(interpolation, InterpolationModes):
raise TypeError("Argument interpolation should be a InterpolationModes")
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
return F_pil.resize(img, size=size, interpolation=interpolation) pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.resize(img, size=size, interpolation=pil_interpolation)
return F_t.resize(img, size=size, interpolation=interpolation) return F_t.resize(img, size=size, interpolation=interpolation.value)
def scale(*args, **kwargs): def scale(*args, **kwargs):
...@@ -424,7 +474,8 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: ...@@ -424,7 +474,8 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
def resized_crop( def resized_crop(
img: Tensor, top: int, left: int, height: int, width: int, size: List[int], interpolation: int = Image.BILINEAR img: Tensor, top: int, left: int, height: int, width: int, size: List[int],
interpolation: InterpolationModes = InterpolationModes.BILINEAR
) -> Tensor: ) -> Tensor:
"""Crop the given image and resize it to desired size. """Crop the given image and resize it to desired size.
The image can be a PIL Image or a Tensor, in which case it is expected The image can be a PIL Image or a Tensor, in which case it is expected
...@@ -439,9 +490,12 @@ def resized_crop( ...@@ -439,9 +490,12 @@ def resized_crop(
height (int): Height of the crop box. height (int): Height of the crop box.
width (int): Width of the crop box. width (int): Width of the crop box.
size (sequence or int): Desired output size. Same semantics as ``resize``. size (sequence or int): Desired output size. Same semantics as ``resize``.
interpolation (int, optional): Desired interpolation enum defined by `filters`_. interpolation (InterpolationModes): Desired interpolation enum defined by
Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` :class:`torchvision.transforms.InterpolationModes`.
and ``PIL.Image.BICUBIC`` are supported. Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``,
``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
Returns: Returns:
PIL Image or Tensor: Cropped image. PIL Image or Tensor: Cropped image.
""" """
...@@ -502,7 +556,7 @@ def perspective( ...@@ -502,7 +556,7 @@ def perspective(
img: Tensor, img: Tensor,
startpoints: List[List[int]], startpoints: List[List[int]],
endpoints: List[List[int]], endpoints: List[List[int]],
interpolation: int = 2, interpolation: InterpolationModes = InterpolationModes.BILINEAR,
fill: Optional[int] = None fill: Optional[int] = None
) -> Tensor: ) -> Tensor:
"""Perform perspective transform of the given image. """Perform perspective transform of the given image.
...@@ -515,8 +569,10 @@ def perspective( ...@@ -515,8 +569,10 @@ def perspective(
``[top-left, top-right, bottom-right, bottom-left]`` of the original image. ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image. ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and interpolation (InterpolationModes): Desired interpolation enum defined by
``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors. :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``.
If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (n-tuple or int or float): Pixel fill value for area outside the rotated fill (n-tuple or int or float): Pixel fill value for area outside the rotated
image. If int or float, the value is used for all bands respectively. image. If int or float, the value is used for all bands respectively.
This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor
...@@ -528,10 +584,22 @@ def perspective( ...@@ -528,10 +584,22 @@ def perspective(
coeffs = _get_perspective_coeffs(startpoints, endpoints) coeffs = _get_perspective_coeffs(startpoints, endpoints)
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationModes instead of int. "
"Please, use InterpolationModes enum."
)
interpolation = _interpolation_modes_from_int(interpolation)
if not isinstance(interpolation, InterpolationModes):
raise TypeError("Argument interpolation should be a InterpolationModes")
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
return F_pil.perspective(img, coeffs, interpolation=interpolation, fill=fill) pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.perspective(img, coeffs, interpolation=pil_interpolation, fill=fill)
return F_t.perspective(img, coeffs, interpolation=interpolation, fill=fill) return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill)
def vflip(img: Tensor) -> Tensor: def vflip(img: Tensor) -> Tensor:
...@@ -801,8 +869,9 @@ def _get_inverse_affine_matrix( ...@@ -801,8 +869,9 @@ def _get_inverse_affine_matrix(
def rotate( def rotate(
img: Tensor, angle: float, resample: int = 0, expand: bool = False, img: Tensor, angle: float, interpolation: InterpolationModes = InterpolationModes.NEAREST,
center: Optional[List[int]] = None, fill: Optional[int] = None expand: bool = False, center: Optional[List[int]] = None,
fill: Optional[int] = None, resample: Optional[int] = None
) -> Tensor: ) -> Tensor:
"""Rotate the image by angle. """Rotate the image by angle.
The image can be a PIL Image or a Tensor, in which case it is expected The image can be a PIL Image or a Tensor, in which case it is expected
...@@ -811,9 +880,10 @@ def rotate( ...@@ -811,9 +880,10 @@ def rotate(
Args: Args:
img (PIL Image or Tensor): image to be rotated. img (PIL Image or Tensor): image to be rotated.
angle (float or int): rotation angle value in degrees, counter-clockwise. angle (float or int): rotation angle value in degrees, counter-clockwise.
resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): interpolation (InterpolationModes): Desired interpolation enum defined by
An optional resampling filter. See `filters`_ for more information. :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``.
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
expand (bool, optional): Optional expansion flag. expand (bool, optional): Optional expansion flag.
If true, expands the output image to make it large enough to hold the entire rotated image. If true, expands the output image to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image. If false or omitted, make the output image the same size as the input image.
...@@ -825,6 +895,8 @@ def rotate( ...@@ -825,6 +895,8 @@ def rotate(
Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
This option is not supported for Tensor input. Fill value for the area outside the transform in the output This option is not supported for Tensor input. Fill value for the area outside the transform in the output
image is always 0. image is always 0.
resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use `arg`:interpolation: instead.
Returns: Returns:
PIL Image or Tensor: Rotated image. PIL Image or Tensor: Rotated image.
...@@ -832,14 +904,32 @@ def rotate( ...@@ -832,14 +904,32 @@ def rotate(
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
""" """
if resample is not None:
warnings.warn(
"Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
)
interpolation = _interpolation_modes_from_int(resample)
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationModes instead of int. "
"Please, use InterpolationModes enum."
)
interpolation = _interpolation_modes_from_int(interpolation)
if not isinstance(angle, (int, float)): if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float") raise TypeError("Argument angle should be int or float")
if center is not None and not isinstance(center, (list, tuple)): if center is not None and not isinstance(center, (list, tuple)):
raise TypeError("Argument center should be a sequence") raise TypeError("Argument center should be a sequence")
if not isinstance(interpolation, InterpolationModes):
raise TypeError("Argument interpolation should be a InterpolationModes")
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
return F_pil.rotate(img, angle=angle, resample=resample, expand=expand, center=center, fill=fill) pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)
center_f = [0.0, 0.0] center_f = [0.0, 0.0]
if center is not None: if center is not None:
...@@ -850,12 +940,13 @@ def rotate( ...@@ -850,12 +940,13 @@ def rotate(
# due to current incoherence of rotation angle direction between affine and rotate implementations # due to current incoherence of rotation angle direction between affine and rotate implementations
# we need to set -angle. # we need to set -angle.
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
return F_t.rotate(img, matrix=matrix, resample=resample, expand=expand, fill=fill) return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill)
def affine( def affine(
img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float], img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
resample: int = 0, fillcolor: Optional[int] = None interpolation: InterpolationModes = InterpolationModes.NEAREST, fill: Optional[int] = None,
resample: Optional[int] = None, fillcolor: Optional[int] = None
) -> Tensor: ) -> Tensor:
"""Apply affine transformation on the image keeping image center invariant. """Apply affine transformation on the image keeping image center invariant.
The image can be a PIL Image or a Tensor, in which case it is expected The image can be a PIL Image or a Tensor, in which case it is expected
...@@ -869,17 +960,41 @@ def affine( ...@@ -869,17 +960,41 @@ def affine(
shear (float or tuple or list): shear angle value in degrees between -180 to 180, clockwise direction. shear (float or tuple or list): shear angle value in degrees between -180 to 180, clockwise direction.
If a tuple of list is specified, the first value corresponds to a shear parallel to the x axis, while If a tuple of list is specified, the first value corresponds to a shear parallel to the x axis, while
the second value corresponds to a shear parallel to the y axis. the second value corresponds to a shear parallel to the y axis.
resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): interpolation (InterpolationModes): Desired interpolation enum defined by
An optional resampling filter. See `filters`_ for more information. :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``.
If omitted, or if the image is PIL Image and has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported.
If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fillcolor (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0). fill (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0).
This option is not supported for Tensor input. Fill value for the area outside the transform in the output This option is not supported for Tensor input. Fill value for the area outside the transform in the output
image is always 0. image is always 0.
fillcolor (tuple or int, optional): deprecated argument and will be removed since v0.10.0.
Please use `arg`:fill: instead.
resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use `arg`:interpolation: instead.
Returns: Returns:
PIL Image or Tensor: Transformed image. PIL Image or Tensor: Transformed image.
""" """
if resample is not None:
warnings.warn(
"Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
)
interpolation = _interpolation_modes_from_int(resample)
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationModes instead of int. "
"Please, use InterpolationModes enum."
)
interpolation = _interpolation_modes_from_int(interpolation)
if fillcolor is not None:
warnings.warn(
"Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead"
)
fill = fillcolor
if not isinstance(angle, (int, float)): if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float") raise TypeError("Argument angle should be int or float")
...@@ -895,6 +1010,9 @@ def affine( ...@@ -895,6 +1010,9 @@ def affine(
if not isinstance(shear, (numbers.Number, (list, tuple))): if not isinstance(shear, (numbers.Number, (list, tuple))):
raise TypeError("Shear should be either a single value or a sequence of two values") raise TypeError("Shear should be either a single value or a sequence of two values")
if not isinstance(interpolation, InterpolationModes):
raise TypeError("Argument interpolation should be a InterpolationModes")
if isinstance(angle, int): if isinstance(angle, int):
angle = float(angle) angle = float(angle)
...@@ -920,12 +1038,12 @@ def affine( ...@@ -920,12 +1038,12 @@ def affine(
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
center = [img_size[0] * 0.5, img_size[1] * 0.5] center = [img_size[0] * 0.5, img_size[1] * 0.5]
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor) return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)
translate_f = [1.0 * t for t in translate] translate_f = [1.0 * t for t in translate]
matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear) matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear)
return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor) return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)
@torch.jit.unused @torch.jit.unused
......
...@@ -474,7 +474,7 @@ def _parse_fill(fill, img, min_pil_version, name="fillcolor"): ...@@ -474,7 +474,7 @@ def _parse_fill(fill, img, min_pil_version, name="fillcolor"):
@torch.jit.unused @torch.jit.unused
def affine(img, matrix, resample=0, fillcolor=None): def affine(img, matrix, interpolation=0, fill=None):
"""PRIVATE METHOD. Apply affine transformation on the PIL Image keeping image center invariant. """PRIVATE METHOD. Apply affine transformation on the PIL Image keeping image center invariant.
.. warning:: .. warning::
...@@ -485,11 +485,11 @@ def affine(img, matrix, resample=0, fillcolor=None): ...@@ -485,11 +485,11 @@ def affine(img, matrix, resample=0, fillcolor=None):
Args: Args:
img (PIL Image): image to be rotated. img (PIL Image): image to be rotated.
matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation.
resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): interpolation (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
An optional resampling filter. An optional resampling filter.
See `filters`_ for more information. See `filters`_ for more information.
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) fill (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
Returns: Returns:
PIL Image: Transformed image. PIL Image: Transformed image.
...@@ -498,12 +498,12 @@ def affine(img, matrix, resample=0, fillcolor=None): ...@@ -498,12 +498,12 @@ def affine(img, matrix, resample=0, fillcolor=None):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
output_size = img.size output_size = img.size
opts = _parse_fill(fillcolor, img, '5.0.0') opts = _parse_fill(fill, img, '5.0.0')
return img.transform(output_size, Image.AFFINE, matrix, resample, **opts) return img.transform(output_size, Image.AFFINE, matrix, interpolation, **opts)
@torch.jit.unused @torch.jit.unused
def rotate(img, angle, resample=0, expand=False, center=None, fill=None): def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None):
"""PRIVATE METHOD. Rotate PIL image by angle. """PRIVATE METHOD. Rotate PIL image by angle.
.. warning:: .. warning::
...@@ -514,7 +514,7 @@ def rotate(img, angle, resample=0, expand=False, center=None, fill=None): ...@@ -514,7 +514,7 @@ def rotate(img, angle, resample=0, expand=False, center=None, fill=None):
Args: Args:
img (PIL Image): image to be rotated. img (PIL Image): image to be rotated.
angle (float or int): rotation angle value in degrees, counter-clockwise. angle (float or int): rotation angle value in degrees, counter-clockwise.
resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): interpolation (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
An optional resampling filter. See `filters`_ for more information. An optional resampling filter. See `filters`_ for more information.
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
expand (bool, optional): Optional expansion flag. expand (bool, optional): Optional expansion flag.
...@@ -538,7 +538,7 @@ def rotate(img, angle, resample=0, expand=False, center=None, fill=None): ...@@ -538,7 +538,7 @@ def rotate(img, angle, resample=0, expand=False, center=None, fill=None):
raise TypeError("img should be PIL Image. Got {}".format(type(img))) raise TypeError("img should be PIL Image. Got {}".format(type(img)))
opts = _parse_fill(fill, img, '5.2.0') opts = _parse_fill(fill, img, '5.2.0')
return img.rotate(angle, resample, expand, center, **opts) return img.rotate(angle, interpolation, expand, center, **opts)
@torch.jit.unused @torch.jit.unused
......
...@@ -757,7 +757,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con ...@@ -757,7 +757,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
return img return img
def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Tensor:
r"""PRIVATE METHOD. Resize the input Tensor to the given size. r"""PRIVATE METHOD. Resize the input Tensor to the given size.
.. warning:: .. warning::
...@@ -774,8 +774,8 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: ...@@ -774,8 +774,8 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`. :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
In torchscript mode padding as a single int is not supported, use a tuple or In torchscript mode padding as a single int is not supported, use a tuple or
list of length 1: ``[size, ]``. list of length 1: ``[size, ]``.
interpolation (int, optional): Desired interpolation. Default is bilinear (=2). Other supported values: interpolation (str): Desired interpolation. Default is "bilinear". Other supported values:
nearest(=0) and bicubic(=3). "nearest" and "bicubic".
Returns: Returns:
Tensor: Resized image. Tensor: Resized image.
...@@ -785,16 +785,10 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: ...@@ -785,16 +785,10 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
if not isinstance(size, (int, tuple, list)): if not isinstance(size, (int, tuple, list)):
raise TypeError("Got inappropriate size arg") raise TypeError("Got inappropriate size arg")
if not isinstance(interpolation, int): if not isinstance(interpolation, str):
raise TypeError("Got inappropriate interpolation arg") raise TypeError("Got inappropriate interpolation arg")
_interpolation_modes = { if interpolation not in ["nearest", "bilinear", "bicubic"]:
0: "nearest",
2: "bilinear",
3: "bicubic",
}
if interpolation not in _interpolation_modes:
raise ValueError("This interpolation mode is unsupported with Tensor input") raise ValueError("This interpolation mode is unsupported with Tensor input")
if isinstance(size, tuple): if isinstance(size, tuple):
...@@ -822,16 +816,14 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: ...@@ -822,16 +816,14 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
if (w <= h and w == size_w) or (h <= w and h == size_h): if (w <= h and w == size_w) or (h <= w and h == size_h):
return img return img
mode = _interpolation_modes[interpolation]
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64]) img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
# Define align_corners to avoid warnings # Define align_corners to avoid warnings
align_corners = False if mode in ["bilinear", "bicubic"] else None align_corners = False if interpolation in ["bilinear", "bicubic"] else None
img = interpolate(img, size=[size_h, size_w], mode=mode, align_corners=align_corners) img = interpolate(img, size=[size_h, size_w], mode=interpolation, align_corners=align_corners)
if mode == "bicubic" and out_dtype == torch.uint8: if interpolation == "bicubic" and out_dtype == torch.uint8:
img = img.clamp(min=0, max=255) img = img.clamp(min=0, max=255)
img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype) img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
...@@ -842,9 +834,9 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: ...@@ -842,9 +834,9 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
def _assert_grid_transform_inputs( def _assert_grid_transform_inputs(
img: Tensor, img: Tensor,
matrix: Optional[List[float]], matrix: Optional[List[float]],
resample: int, interpolation: str,
fillcolor: Optional[int], fill: Optional[int],
_interpolation_modes: Dict[int, str], supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None, coeffs: Optional[List[float]] = None,
): ):
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
...@@ -859,11 +851,11 @@ def _assert_grid_transform_inputs( ...@@ -859,11 +851,11 @@ def _assert_grid_transform_inputs(
if coeffs is not None and len(coeffs) != 8: if coeffs is not None and len(coeffs) != 8:
raise ValueError("Argument coeffs should have 8 float values") raise ValueError("Argument coeffs should have 8 float values")
if fillcolor is not None: if fill is not None and not (isinstance(fill, (int, float)) and fill == 0):
warnings.warn("Argument fill/fillcolor is not supported for Tensor input. Fill value is zero") warnings.warn("Argument fill is not supported for Tensor input. Fill value is zero")
if resample not in _interpolation_modes: if interpolation not in supported_interpolation_modes:
raise ValueError("Resampling mode '{}' is unsupported with Tensor input".format(resample)) raise ValueError("Interpolation mode '{}' is unsupported with Tensor input".format(interpolation))
def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]: def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
...@@ -931,7 +923,7 @@ def _gen_affine_grid( ...@@ -931,7 +923,7 @@ def _gen_affine_grid(
def affine( def affine(
img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[int] = None
) -> Tensor: ) -> Tensor:
"""PRIVATE METHOD. Apply affine transformation on the Tensor image keeping image center invariant. """PRIVATE METHOD. Apply affine transformation on the Tensor image keeping image center invariant.
...@@ -943,28 +935,21 @@ def affine( ...@@ -943,28 +935,21 @@ def affine(
Args: Args:
img (Tensor): image to be rotated. img (Tensor): image to be rotated.
matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation.
resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values: interpolation (str): An optional resampling filter. Default is "nearest". Other supported values: "bilinear".
bilinear(=2). fill (int, optional): this option is not supported for Tensor input. Fill value for the area outside the
fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the
transform in the output image is always 0. transform in the output image is always 0.
Returns: Returns:
Tensor: Transformed image. Tensor: Transformed image.
""" """
_interpolation_modes = { _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
0: "nearest",
2: "bilinear",
}
_assert_grid_transform_inputs(img, matrix, resample, fillcolor, _interpolation_modes)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32 dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
shape = img.shape shape = img.shape
# grid will be generated on the same device as theta and img # grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2]) grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
mode = _interpolation_modes[resample] return _apply_grid_transform(img, grid, interpolation)
return _apply_grid_transform(img, grid, mode)
def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]: def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
...@@ -993,7 +978,8 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] ...@@ -993,7 +978,8 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
def rotate( def rotate(
img: Tensor, matrix: List[float], resample: int = 0, expand: bool = False, fill: Optional[int] = None img: Tensor, matrix: List[float], interpolation: str = "nearest",
expand: bool = False, fill: Optional[int] = None
) -> Tensor: ) -> Tensor:
"""PRIVATE METHOD. Rotate the Tensor image by angle. """PRIVATE METHOD. Rotate the Tensor image by angle.
...@@ -1006,8 +992,7 @@ def rotate( ...@@ -1006,8 +992,7 @@ def rotate(
img (Tensor): image to be rotated. img (Tensor): image to be rotated.
matrix (list of floats): list of 6 float values representing inverse matrix for rotation transformation. matrix (list of floats): list of 6 float values representing inverse matrix for rotation transformation.
Translation part (``matrix[2]`` and ``matrix[5]``) should be in pixel coordinates. Translation part (``matrix[2]`` and ``matrix[5]``) should be in pixel coordinates.
resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values: interpolation (str): An optional resampling filter. Default is "nearest". Other supported values: "bilinear".
bilinear(=2).
expand (bool, optional): Optional expansion flag. expand (bool, optional): Optional expansion flag.
If true, expands the output image to make it large enough to hold the entire rotated image. If true, expands the output image to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image. If false or omitted, make the output image the same size as the input image.
...@@ -1021,21 +1006,14 @@ def rotate( ...@@ -1021,21 +1006,14 @@ def rotate(
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
""" """
_interpolation_modes = { _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
0: "nearest",
2: "bilinear",
}
_assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes)
w, h = img.shape[-1], img.shape[-2] w, h = img.shape[-1], img.shape[-2]
ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h) ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32 dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
# grid will be generated on the same device as theta and img # grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh) grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
mode = _interpolation_modes[resample] return _apply_grid_transform(img, grid, interpolation)
return _apply_grid_transform(img, grid, mode)
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device): def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device):
...@@ -1072,7 +1050,7 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, ...@@ -1072,7 +1050,7 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
def perspective( def perspective(
img: Tensor, perspective_coeffs: List[float], interpolation: int = 2, fill: Optional[int] = None img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[int] = None
) -> Tensor: ) -> Tensor:
"""PRIVATE METHOD. Perform perspective transform of the given Tensor image. """PRIVATE METHOD. Perform perspective transform of the given Tensor image.
...@@ -1084,7 +1062,7 @@ def perspective( ...@@ -1084,7 +1062,7 @@ def perspective(
Args: Args:
img (Tensor): Image to be transformed. img (Tensor): Image to be transformed.
perspective_coeffs (list of float): perspective transformation coefficients. perspective_coeffs (list of float): perspective transformation coefficients.
interpolation (int): Interpolation type. Default, ``PIL.Image.BILINEAR``. interpolation (str): Interpolation type. Default, "bilinear".
fill (n-tuple or int or float): this option is not supported for Tensor input. Fill value for the area fill (n-tuple or int or float): this option is not supported for Tensor input. Fill value for the area
outside the transform in the output image is always 0. outside the transform in the output image is always 0.
...@@ -1094,26 +1072,19 @@ def perspective( ...@@ -1094,26 +1072,19 @@ def perspective(
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError('Input img should be Tensor Image') raise TypeError('Input img should be Tensor Image')
_interpolation_modes = {
0: "nearest",
2: "bilinear",
}
_assert_grid_transform_inputs( _assert_grid_transform_inputs(
img, img,
matrix=None, matrix=None,
resample=interpolation, interpolation=interpolation,
fillcolor=fill, fill=fill,
_interpolation_modes=_interpolation_modes, supported_interpolation_modes=["nearest", "bilinear"],
coeffs=perspective_coeffs coeffs=perspective_coeffs
) )
ow, oh = img.shape[-1], img.shape[-2] ow, oh = img.shape[-1], img.shape[-2]
dtype = img.dtype if torch.is_floating_point(img) else torch.float32 dtype = img.dtype if torch.is_floating_point(img) else torch.float32
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device) grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
mode = _interpolation_modes[interpolation] return _apply_grid_transform(img, grid, interpolation)
return _apply_grid_transform(img, grid, mode)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor: def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
......
...@@ -6,7 +6,6 @@ from collections.abc import Sequence ...@@ -6,7 +6,6 @@ from collections.abc import Sequence
from typing import Tuple, List, Optional from typing import Tuple, List, Optional
import torch import torch
from PIL import Image
from torch import Tensor from torch import Tensor
try: try:
...@@ -15,21 +14,14 @@ except ImportError: ...@@ -15,21 +14,14 @@ except ImportError:
accimage = None accimage = None
from . import functional as F from . import functional as F
from .functional import InterpolationModes, _interpolation_modes_from_int
__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", __all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
"RandomPerspective", "RandomErasing", "GaussianBlur"] "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationModes"]
_pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST',
Image.BILINEAR: 'PIL.Image.BILINEAR',
Image.BICUBIC: 'PIL.Image.BICUBIC',
Image.LANCZOS: 'PIL.Image.LANCZOS',
Image.HAMMING: 'PIL.Image.HAMMING',
Image.BOX: 'PIL.Image.BOX',
}
class Compose: class Compose:
...@@ -242,18 +234,30 @@ class Resize(torch.nn.Module): ...@@ -242,18 +234,30 @@ class Resize(torch.nn.Module):
(size * height / width, size). (size * height / width, size).
In torchscript mode padding as single int is not supported, use a tuple or In torchscript mode padding as single int is not supported, use a tuple or
list of length 1: ``[size, ]``. list of length 1: ``[size, ]``.
interpolation (int, optional): Desired interpolation enum defined by `filters`_. interpolation (InterpolationModes): Desired interpolation enum defined by
Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``.
and ``PIL.Image.BICUBIC`` are supported. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` and
``InterpolationModes.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
""" """
def __init__(self, size, interpolation=Image.BILINEAR): def __init__(self, size, interpolation=InterpolationModes.BILINEAR):
super().__init__() super().__init__()
if not isinstance(size, (int, Sequence)): if not isinstance(size, (int, Sequence)):
raise TypeError("Size should be int or sequence. Got {}".format(type(size))) raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
if isinstance(size, Sequence) and len(size) not in (1, 2): if isinstance(size, Sequence) and len(size) not in (1, 2):
raise ValueError("If size is a sequence, it should have 1 or 2 values") raise ValueError("If size is a sequence, it should have 1 or 2 values")
self.size = size self.size = size
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationModes instead of int. "
"Please, use InterpolationModes enum."
)
interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation self.interpolation = interpolation
def forward(self, img): def forward(self, img):
...@@ -267,7 +271,7 @@ class Resize(torch.nn.Module): ...@@ -267,7 +271,7 @@ class Resize(torch.nn.Module):
return F.resize(img, self.size, self.interpolation) return F.resize(img, self.size, self.interpolation)
def __repr__(self): def __repr__(self):
interpolate_str = _pil_interpolation_to_str[self.interpolation] interpolate_str = self.interpolation.value
return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str) return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
...@@ -659,18 +663,28 @@ class RandomPerspective(torch.nn.Module): ...@@ -659,18 +663,28 @@ class RandomPerspective(torch.nn.Module):
distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1. distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
Default is 0.5. Default is 0.5.
p (float): probability of the image being transformed. Default is 0.5. p (float): probability of the image being transformed. Default is 0.5.
interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and interpolation (InterpolationModes): Desired interpolation enum defined by
``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors. :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``.
If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (n-tuple or int or float): Pixel fill value for area outside the rotated fill (n-tuple or int or float): Pixel fill value for area outside the rotated
image. If int or float, the value is used for all bands respectively. Default is 0. image. If int or float, the value is used for all bands respectively. Default is 0.
This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor
input. Fill value for the area outside the transform in the output image is always 0. input. Fill value for the area outside the transform in the output image is always 0.
""" """
def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BILINEAR, fill=0): def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationModes.BILINEAR, fill=0):
super().__init__() super().__init__()
self.p = p self.p = p
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationModes instead of int. "
"Please, use InterpolationModes enum."
)
interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation self.interpolation = interpolation
self.distortion_scale = distortion_scale self.distortion_scale = distortion_scale
self.fill = fill self.fill = fill
...@@ -744,12 +758,15 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -744,12 +758,15 @@ class RandomResizedCrop(torch.nn.Module):
made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
scale (tuple of float): scale range of the cropped image before resizing, relatively to the origin image. scale (tuple of float): scale range of the cropped image before resizing, relatively to the origin image.
ratio (tuple of float): aspect ratio range of the cropped image before resizing. ratio (tuple of float): aspect ratio range of the cropped image before resizing.
interpolation (int): Desired interpolation enum defined by `filters`_. interpolation (InterpolationModes): Desired interpolation enum defined by
Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``.
and ``PIL.Image.BICUBIC`` are supported. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` and
``InterpolationModes.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
""" """
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationModes.BILINEAR):
super().__init__() super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
...@@ -760,6 +777,14 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -760,6 +777,14 @@ class RandomResizedCrop(torch.nn.Module):
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)") warnings.warn("Scale and ratio should be of kind (min, max)")
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationModes instead of int. "
"Please, use InterpolationModes enum."
)
interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation self.interpolation = interpolation
self.scale = scale self.scale = scale
self.ratio = ratio self.ratio = ratio
...@@ -824,7 +849,7 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -824,7 +849,7 @@ class RandomResizedCrop(torch.nn.Module):
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
def __repr__(self): def __repr__(self):
interpolate_str = _pil_interpolation_to_str[self.interpolation] interpolate_str = self.interpolation.value
format_string = self.__class__.__name__ + '(size={0}'.format(self.size) format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
...@@ -1122,9 +1147,10 @@ class RandomRotation(torch.nn.Module): ...@@ -1122,9 +1147,10 @@ class RandomRotation(torch.nn.Module):
degrees (sequence or float or int): Range of degrees to select from. degrees (sequence or float or int): Range of degrees to select from.
If degrees is a number instead of sequence like (min, max), the range of degrees If degrees is a number instead of sequence like (min, max), the range of degrees
will be (-degrees, +degrees). will be (-degrees, +degrees).
resample (int, optional): An optional resampling filter. See `filters`_ for more information. interpolation (InterpolationModes): Desired interpolation enum defined by
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``.
If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
expand (bool, optional): Optional expansion flag. expand (bool, optional): Optional expansion flag.
If true, expands the output to make it large enough to hold the entire rotated image. If true, expands the output to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image. If false or omitted, make the output image the same size as the input image.
...@@ -1136,13 +1162,31 @@ class RandomRotation(torch.nn.Module): ...@@ -1136,13 +1162,31 @@ class RandomRotation(torch.nn.Module):
Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0. Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0.
This option is not supported for Tensor input. Fill value for the area outside the transform in the output This option is not supported for Tensor input. Fill value for the area outside the transform in the output
image is always 0. image is always 0.
resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use `arg`:interpolation: instead.
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
""" """
def __init__(self, degrees, resample=False, expand=False, center=None, fill=None): def __init__(
self, degrees, interpolation=InterpolationModes.NEAREST, expand=False, center=None, fill=None, resample=None
):
super().__init__() super().__init__()
if resample is not None:
warnings.warn(
"Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
)
interpolation = _interpolation_modes_from_int(resample)
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationModes instead of int. "
"Please, use InterpolationModes enum."
)
interpolation = _interpolation_modes_from_int(interpolation)
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))
if center is not None: if center is not None:
...@@ -1150,7 +1194,7 @@ class RandomRotation(torch.nn.Module): ...@@ -1150,7 +1194,7 @@ class RandomRotation(torch.nn.Module):
self.center = center self.center = center
self.resample = resample self.resample = self.interpolation = interpolation
self.expand = expand self.expand = expand
self.fill = fill self.fill = fill
...@@ -1173,11 +1217,12 @@ class RandomRotation(torch.nn.Module): ...@@ -1173,11 +1217,12 @@ class RandomRotation(torch.nn.Module):
PIL Image or Tensor: Rotated image. PIL Image or Tensor: Rotated image.
""" """
angle = self.get_params(self.degrees) angle = self.get_params(self.degrees)
return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill) return F.rotate(img, angle, self.interpolation, self.expand, self.center, self.fill)
def __repr__(self): def __repr__(self):
interpolate_str = self.interpolation.value
format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
format_string += ', resample={0}'.format(self.resample) format_string += ', interpolation={0}'.format(interpolate_str)
format_string += ', expand={0}'.format(self.expand) format_string += ', expand={0}'.format(self.expand)
if self.center is not None: if self.center is not None:
format_string += ', center={0}'.format(self.center) format_string += ', center={0}'.format(self.center)
...@@ -1208,19 +1253,47 @@ class RandomAffine(torch.nn.Module): ...@@ -1208,19 +1253,47 @@ class RandomAffine(torch.nn.Module):
range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values, range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
Will not apply shear by default. Will not apply shear by default.
resample (int, optional): An optional resampling filter. See `filters`_ for more information. interpolation (InterpolationModes): Desired interpolation enum defined by
If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``.
If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported.
fillcolor (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area
outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor
input. Fill value for the area outside the transform in the output image is always 0. input. Fill value for the area outside the transform in the output image is always 0.
fillcolor (tuple or int, optional): deprecated argument and will be removed since v0.10.0.
Please use `arg`:fill: instead.
resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use `arg`:interpolation: instead.
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
""" """
def __init__(self, degrees, translate=None, scale=None, shear=None, resample=0, fillcolor=0): def __init__(
self, degrees, translate=None, scale=None, shear=None, interpolation=InterpolationModes.NEAREST, fill=0,
fillcolor=None, resample=None
):
super().__init__() super().__init__()
if resample is not None:
warnings.warn(
"Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
)
interpolation = _interpolation_modes_from_int(resample)
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationModes instead of int. "
"Please, use InterpolationModes enum."
)
interpolation = _interpolation_modes_from_int(interpolation)
if fillcolor is not None:
warnings.warn(
"Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead"
)
fill = fillcolor
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))
if translate is not None: if translate is not None:
...@@ -1242,8 +1315,8 @@ class RandomAffine(torch.nn.Module): ...@@ -1242,8 +1315,8 @@ class RandomAffine(torch.nn.Module):
else: else:
self.shear = shear self.shear = shear
self.resample = resample self.resample = self.interpolation = interpolation
self.fillcolor = fillcolor self.fillcolor = self.fill = fill
@staticmethod @staticmethod
def get_params( def get_params(
...@@ -1294,7 +1367,7 @@ class RandomAffine(torch.nn.Module): ...@@ -1294,7 +1367,7 @@ class RandomAffine(torch.nn.Module):
img_size = F._get_image_size(img) img_size = F._get_image_size(img)
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor) return F.affine(img, *ret, interpolation=self.interpolation, fill=self.fill)
def __repr__(self): def __repr__(self):
s = '{name}(degrees={degrees}' s = '{name}(degrees={degrees}'
...@@ -1304,13 +1377,13 @@ class RandomAffine(torch.nn.Module): ...@@ -1304,13 +1377,13 @@ class RandomAffine(torch.nn.Module):
s += ', scale={scale}' s += ', scale={scale}'
if self.shear is not None: if self.shear is not None:
s += ', shear={shear}' s += ', shear={shear}'
if self.resample > 0: if self.interpolation != InterpolationModes.NEAREST:
s += ', resample={resample}' s += ', interpolation={interpolation}'
if self.fillcolor != 0: if self.fill != 0:
s += ', fillcolor={fillcolor}' s += ', fill={fill}'
s += ')' s += ')'
d = dict(self.__dict__) d = dict(self.__dict__)
d['resample'] = _pil_interpolation_to_str[d['resample']] d['interpolation'] = self.interpolation.value
return s.format(name=self.__class__.__name__, **d) return s.format(name=self.__class__.__name__, **d)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment