Unverified Commit 6fcf0a27 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Fix resize when size == small_edge_size and max_size isn't None (#5409)

* Fix resize when size == small_edge_size and max_size isn't None

* Better test name
parent 26fe8fad
...@@ -440,6 +440,19 @@ def test_resize_antialias_error(): ...@@ -440,6 +440,19 @@ def test_resize_antialias_error():
t(img) t(img)
@pytest.mark.parametrize("height, width", ((32, 64), (64, 32)))
def test_resize_size_equals_small_edge_size(height, width):
# Non-regression test for https://github.com/pytorch/vision/issues/5405
# max_size used to be ignored if size == small_edge_size
max_size = 40
img = Image.new("RGB", size=(width, height), color=127)
small_edge = min(height, width)
t = transforms.Resize(small_edge, max_size=max_size)
result = t(img)
assert max(result.size) == max_size
class TestPad: class TestPad:
def test_pad(self): def test_pad(self):
height = random.randint(10, 32) * 2 height = random.randint(10, 32) * 2
......
...@@ -240,9 +240,6 @@ def resize( ...@@ -240,9 +240,6 @@ def resize(
w, h = img.size w, h = img.size
short, long = (w, h) if w <= h else (h, w) short, long = (w, h) if w <= h else (h, w)
if short == size:
return img
new_short, new_long = size, int(size * long / short) new_short, new_long = size, int(size * long / short)
if max_size is not None: if max_size is not None:
...@@ -255,6 +252,10 @@ def resize( ...@@ -255,6 +252,10 @@ def resize(
new_short, new_long = int(max_size * new_short / new_long), max_size new_short, new_long = int(max_size * new_short / new_long), max_size
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short) new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
if (w, h) == (new_w, new_h):
return img
else:
return img.resize((new_w, new_h), interpolation) return img.resize((new_w, new_h), interpolation)
else: else:
if max_size is not None: if max_size is not None:
......
...@@ -457,9 +457,6 @@ def resize( ...@@ -457,9 +457,6 @@ def resize(
short, long = (w, h) if w <= h else (h, w) short, long = (w, h) if w <= h else (h, w)
requested_new_short = size if isinstance(size, int) else size[0] requested_new_short = size if isinstance(size, int) else size[0]
if short == requested_new_short:
return img
new_short, new_long = requested_new_short, int(requested_new_short * long / short) new_short, new_long = requested_new_short, int(requested_new_short * long / short)
if max_size is not None: if max_size is not None:
...@@ -473,6 +470,9 @@ def resize( ...@@ -473,6 +470,9 @@ def resize(
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short) new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
if (w, h) == (new_w, new_h):
return img
else: # specified both h and w else: # specified both h and w
new_w, new_h = size[1], size[0] new_w, new_h = size[1], size[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment