Unverified Commit 8fdaeb03 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Image and Mask can accept PIL images (#7231)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent e21e4362
...@@ -835,7 +835,7 @@ KERNEL_INFOS.extend( ...@@ -835,7 +835,7 @@ KERNEL_INFOS.extend(
F.rotate_bounding_box, F.rotate_bounding_box,
sample_inputs_fn=sample_inputs_rotate_bounding_box, sample_inputs_fn=sample_inputs_rotate_bounding_box,
closeness_kwargs={ closeness_kwargs={
**scripted_vs_eager_double_pixel_difference("cpu", atol=1e-6, rtol=1e-6), **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
**scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5), **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
}, },
), ),
......
import pytest import pytest
import torch import torch
from PIL import Image
from torchvision.prototype import datapoints from torchvision.prototype import datapoints
...@@ -130,3 +132,30 @@ def test_wrap_like(): ...@@ -130,3 +132,30 @@ def test_wrap_like():
assert type(label_new) is datapoints.Label assert type(label_new) is datapoints.Label
assert label_new.data_ptr() == output.data_ptr() assert label_new.data_ptr() == output.data_ptr()
assert label_new.categories is label.categories assert label_new.categories is label.categories
@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
def test_image_instance(data):
image = datapoints.Image(data)
assert isinstance(image, torch.Tensor)
assert image.ndim == 3 and image.shape[0] == 3
@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
def test_mask_instance(data):
mask = datapoints.Mask(data)
assert isinstance(mask, torch.Tensor)
assert mask.ndim == 3 and mask.shape[0] == 1
@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]])
@pytest.mark.parametrize(
"format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
)
def test_bbox_instance(data, format):
bboxes = datapoints.BoundingBox(data, format=format, spatial_size=(32, 32))
assert isinstance(bboxes, torch.Tensor)
assert bboxes.ndim == 2 and bboxes.shape[1] == 4
if isinstance(format, str):
format = datapoints.BoundingBoxFormat.from_str(format.upper())
assert bboxes.format == format
...@@ -99,7 +99,7 @@ def identity(item): ...@@ -99,7 +99,7 @@ def identity(item):
def pil_image_to_mask(pil_image): def pil_image_to_mask(pil_image):
return datapoints.Mask(F.to_image_tensor(pil_image).squeeze(0)) return datapoints.Mask(pil_image)
def list_of_dicts_to_dict_of_lists(list_of_dicts): def list_of_dicts_to_dict_of_lists(list_of_dicts):
......
...@@ -23,6 +23,11 @@ class Image(Datapoint): ...@@ -23,6 +23,11 @@ class Image(Datapoint):
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: Optional[bool] = None, requires_grad: Optional[bool] = None,
) -> Image: ) -> Image:
if isinstance(data, PIL.Image.Image):
from torchvision.prototype.transforms import functional as F
data = F.pil_to_tensor(data)
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
if tensor.ndim < 2: if tensor.ndim < 2:
raise ValueError raise ValueError
......
...@@ -2,6 +2,7 @@ from __future__ import annotations ...@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Any, List, Optional, Tuple, Union from typing import Any, List, Optional, Tuple, Union
import PIL.Image
import torch import torch
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
...@@ -21,6 +22,11 @@ class Mask(Datapoint): ...@@ -21,6 +22,11 @@ class Mask(Datapoint):
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: Optional[bool] = None, requires_grad: Optional[bool] = None,
) -> Mask: ) -> Mask:
if isinstance(data, PIL.Image.Image):
from torchvision.prototype.transforms import functional as F
data = F.pil_to_tensor(data)
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
return cls._wrap(tensor) return cls._wrap(tensor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment