Unverified Commit 8088cc94 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Fixed bug with Resize.size if input is integer (#2869)

parent d1e134c7
...@@ -219,6 +219,11 @@ class Tester(unittest.TestCase): ...@@ -219,6 +219,11 @@ class Tester(unittest.TestCase):
width = random.randint(24, 32) * 2 width = random.randint(24, 32) * 2
osize = random.randint(5, 12) * 2 osize = random.randint(5, 12) * 2
# TODO: Check output size check for bug-fix, improve this later
t = transforms.Resize(osize)
self.assertTrue(isinstance(t.size, int))
self.assertEqual(t.size, osize)
img = torch.ones(3, height, width) img = torch.ones(3, height, width)
result = transforms.Compose([ result = transforms.Compose([
transforms.ToPILImage(), transforms.ToPILImage(),
......
...@@ -280,6 +280,17 @@ class Tester(TransformsTester): ...@@ -280,6 +280,17 @@ class Tester(TransformsTester):
) )
def test_resize(self): def test_resize(self):
# TODO: Minimal check for bug-fix, improve this later
x = torch.rand(3, 32, 46)
t = T.Resize(size=38)
y = t(x)
# If size is an int, smaller edge of the image will be matched to this number.
# i.e, if height > width, then image will be rescaled to (size * height / width, size).
self.assertTrue(isinstance(y, torch.Tensor))
self.assertEqual(y.shape[1], 38)
self.assertEqual(y.shape[2], int(38 * 46 / 32))
tensor, _ = self._create_data(height=34, width=36, device=self.device) tensor, _ = self._create_data(height=34, width=36, device=self.device)
batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device) batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
script_fn = torch.jit.script(F.resize) script_fn = torch.jit.script(F.resize)
......
...@@ -249,7 +249,11 @@ class Resize(torch.nn.Module): ...@@ -249,7 +249,11 @@ class Resize(torch.nn.Module):
def __init__(self, size, interpolation=Image.BILINEAR): def __init__(self, size, interpolation=Image.BILINEAR):
super().__init__() super().__init__()
self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values") if not isinstance(size, (int, Sequence)):
raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
if isinstance(size, Sequence) and len(size) not in (1, 2):
raise ValueError("If size is a sequence, it should have 1 or 2 values")
self.size = size
self.interpolation = interpolation self.interpolation = interpolation
def forward(self, img): def forward(self, img):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment