Unverified commit 15c166ac authored by Philip Meier, committed by GitHub
Browse files

refactor to_pil_image and align array with tensor inputs (#8097)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent a0fcd083
...@@ -661,7 +661,7 @@ class TestToPil: ...@@ -661,7 +661,7 @@ class TestToPil:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"img_data, expected_mode", "img_data, expected_mode",
[ [
(torch.Tensor(4, 4, 1).uniform_().numpy(), "F"), (torch.Tensor(4, 4, 1).uniform_().numpy(), "L"),
(torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), "L"), (torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), "L"),
(torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16"), (torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16"),
(torch.IntTensor(4, 4, 1).random_().numpy(), "I"), (torch.IntTensor(4, 4, 1).random_().numpy(), "I"),
...@@ -671,6 +671,8 @@ class TestToPil: ...@@ -671,6 +671,8 @@ class TestToPil:
transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
img = transform(img_data) img = transform(img_data)
assert img.mode == expected_mode assert img.mode == expected_mode
if np.issubdtype(img_data.dtype, np.floating):
img_data = (img_data * 255).astype(np.uint8)
# note: we explicitly convert img's dtype because pytorch doesn't support uint16 # note: we explicitly convert img's dtype because pytorch doesn't support uint16
# and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array # and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array
torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype)) torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype))
...@@ -741,7 +743,7 @@ class TestToPil: ...@@ -741,7 +743,7 @@ class TestToPil:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"img_data, expected_mode", "img_data, expected_mode",
[ [
(torch.Tensor(4, 4).uniform_().numpy(), "F"), (torch.Tensor(4, 4).uniform_().numpy(), "L"),
(torch.ByteTensor(4, 4).random_(0, 255).numpy(), "L"), (torch.ByteTensor(4, 4).random_(0, 255).numpy(), "L"),
(torch.ShortTensor(4, 4).random_().numpy(), "I;16"), (torch.ShortTensor(4, 4).random_().numpy(), "I;16"),
(torch.IntTensor(4, 4).random_().numpy(), "I"), (torch.IntTensor(4, 4).random_().numpy(), "I"),
...@@ -751,6 +753,8 @@ class TestToPil: ...@@ -751,6 +753,8 @@ class TestToPil:
transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
img = transform(img_data) img = transform(img_data)
assert img.mode == expected_mode assert img.mode == expected_mode
if np.issubdtype(img_data.dtype, np.floating):
img_data = (img_data * 255).astype(np.uint8)
np.testing.assert_allclose(img_data, img) np.testing.assert_allclose(img_data, img)
@pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"]) @pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"])
...@@ -874,8 +878,6 @@ class TestToPil: ...@@ -874,8 +878,6 @@ class TestToPil:
trans(np.ones([4, 4, 1], np.uint16)) trans(np.ones([4, 4, 1], np.uint16))
with pytest.raises(TypeError, match=reg_msg): with pytest.raises(TypeError, match=reg_msg):
trans(np.ones([4, 4, 1], np.uint32)) trans(np.ones([4, 4, 1], np.uint32))
with pytest.raises(TypeError, match=reg_msg):
trans(np.ones([4, 4, 1], np.float64))
with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."): with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."):
transforms.ToPILImage()(np.ones([1, 4, 4, 3])) transforms.ToPILImage()(np.ones([1, 4, 4, 3]))
......
...@@ -258,41 +258,26 @@ def to_pil_image(pic, mode=None): ...@@ -258,41 +258,26 @@ def to_pil_image(pic, mode=None):
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(to_pil_image) _log_api_usage_once(to_pil_image)
if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)): if isinstance(pic, torch.Tensor):
if pic.ndim == 3:
pic = pic.permute((1, 2, 0))
pic = pic.numpy(force=True)
elif not isinstance(pic, np.ndarray):
raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.") raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")
elif isinstance(pic, torch.Tensor): if pic.ndim == 2:
if pic.ndimension() not in {2, 3}:
raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndimension()} dimensions.")
elif pic.ndimension() == 2:
# if 2D image, add channel dimension (CHW)
pic = pic.unsqueeze(0)
# check number of channels
if pic.shape[-3] > 4:
raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-3]} channels.")
elif isinstance(pic, np.ndarray):
if pic.ndim not in {2, 3}:
raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")
elif pic.ndim == 2:
# if 2D image, add channel dimension (HWC) # if 2D image, add channel dimension (HWC)
pic = np.expand_dims(pic, 2) pic = np.expand_dims(pic, 2)
if pic.ndim != 3:
raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")
# check number of channels
if pic.shape[-1] > 4: if pic.shape[-1] > 4:
raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.") raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")
npimg = pic npimg = pic
if isinstance(pic, torch.Tensor):
if pic.is_floating_point() and mode != "F":
pic = pic.mul(255).byte()
npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))
if not isinstance(npimg, np.ndarray): if np.issubdtype(npimg.dtype, np.floating) and mode != "F":
raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray, not {type(npimg)}") npimg = (npimg * 255).astype(np.uint8)
if npimg.shape[2] == 1: if npimg.shape[2] == 1:
expected_mode = None expected_mode = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment