Unverified Commit 36d0e3e6 authored by Mantas's avatar Mantas Committed by GitHub
Browse files

Allow 2D numpy arrays as inputs for `to_image` (#8256)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent 8e712070
...@@ -5182,6 +5182,11 @@ class TestToImage: ...@@ -5182,6 +5182,11 @@ class TestToImage:
if isinstance(input, torch.Tensor): if isinstance(input, torch.Tensor):
assert output.data_ptr() == input.data_ptr() assert output.data_ptr() == input.data_ptr()
def test_2d_np_array(self):
# Non-regression test for https://github.com/pytorch/vision/issues/8255
input = np.random.rand(10, 10)
assert F.to_image(input).shape == (1, 10, 10)
def test_functional_error(self): def test_functional_error(self):
with pytest.raises(TypeError, match="Input can either be a pure Tensor, a numpy array, or a PIL image"): with pytest.raises(TypeError, match="Input can either be a pure Tensor, a numpy array, or a PIL image"):
F.to_image(object()) F.to_image(object())
......
...@@ -11,7 +11,7 @@ from torchvision.transforms import functional as _F ...@@ -11,7 +11,7 @@ from torchvision.transforms import functional as _F
def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tensors.Image: def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tensors.Image:
"""See :class:`~torchvision.transforms.v2.ToImage` for details.""" """See :class:`~torchvision.transforms.v2.ToImage` for details."""
if isinstance(inpt, np.ndarray): if isinstance(inpt, np.ndarray):
output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous() output = torch.from_numpy(np.atleast_3d(inpt)).permute((2, 0, 1)).contiguous()
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
output = pil_to_tensor(inpt) output = pil_to_tensor(inpt)
elif isinstance(inpt, torch.Tensor): elif isinstance(inpt, torch.Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment