Unverified commit 8fdaeb03, authored by vfdev and committed by GitHub

Image and Mask can accept PIL images (#7231)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent e21e4362
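
A minimal usage sketch of what this change enables (not part of the commit; sizes and values mirror the new tests in the diff below):

import PIL.Image
import torch

from torchvision.prototype import datapoints

# Image and Mask now accept PIL images directly and convert them to tensors
# internally, keeping the channel-first (C, H, W) layout.
image = datapoints.Image(PIL.Image.new("RGB", (32, 32), color=123))
mask = datapoints.Mask(PIL.Image.new("L", (32, 32), color=2))

assert isinstance(image, torch.Tensor) and image.shape == (3, 32, 32)
assert isinstance(mask, torch.Tensor) and mask.shape == (1, 32, 32)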
@@ -835,7 +835,7 @@ KERNEL_INFOS.extend(
            F.rotate_bounding_box,
            sample_inputs_fn=sample_inputs_rotate_bounding_box,
            closeness_kwargs={
-                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-6, rtol=1e-6),
+                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
            },
        ),
......
import pytest
import torch
+from PIL import Image
from torchvision.prototype import datapoints
@@ -130,3 +132,30 @@ def test_wrap_like():
    assert type(label_new) is datapoints.Label
    assert label_new.data_ptr() == output.data_ptr()
    assert label_new.categories is label.categories
+
+
+@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
+def test_image_instance(data):
+    image = datapoints.Image(data)
+    assert isinstance(image, torch.Tensor)
+    assert image.ndim == 3 and image.shape[0] == 3
+
+
+@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
+def test_mask_instance(data):
+    mask = datapoints.Mask(data)
+    assert isinstance(mask, torch.Tensor)
+    assert mask.ndim == 3 and mask.shape[0] == 1
+
+
+@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]])
+@pytest.mark.parametrize(
+    "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
+)
+def test_bbox_instance(data, format):
+    bboxes = datapoints.BoundingBox(data, format=format, spatial_size=(32, 32))
+    assert isinstance(bboxes, torch.Tensor)
+    assert bboxes.ndim == 2 and bboxes.shape[1] == 4
+    if isinstance(format, str):
+        format = datapoints.BoundingBoxFormat.from_str(format.upper())
+    assert bboxes.format == format
@@ -99,7 +99,7 @@ def identity(item):
def pil_image_to_mask(pil_image):
-    return datapoints.Mask(F.to_image_tensor(pil_image).squeeze(0))
+    return datapoints.Mask(pil_image)


def list_of_dicts_to_dict_of_lists(list_of_dicts):
......
@@ -23,6 +23,11 @@ class Image(Datapoint):
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Image:
+        if isinstance(data, PIL.Image.Image):
+            from torchvision.prototype.transforms import functional as F
+
+            data = F.pil_to_tensor(data)
+
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        if tensor.ndim < 2:
            raise ValueError
......
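
For context (not from the commit): the branch added above defers to pil_to_tensor, which returns a channel-first uint8 tensor; grayscale inputs keep an explicit leading channel, which is why the Mask test above expects shape[0] == 1. The stable functional is used in this sketch on the assumption that it behaves like the prototype one imported in the diff:

import PIL.Image
from torchvision.transforms.functional import pil_to_tensor

rgb = pil_to_tensor(PIL.Image.new("RGB", (32, 32)))
gray = pil_to_tensor(PIL.Image.new("L", (32, 32)))

print(rgb.shape, rgb.dtype)    # torch.Size([3, 32, 32]) torch.uint8
print(gray.shape, gray.dtype)  # torch.Size([1, 32, 32]) torch.uint8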
@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import Any, List, Optional, Tuple, Union

+import PIL.Image
import torch
from torchvision.transforms import InterpolationMode
@@ -21,6 +22,11 @@ class Mask(Datapoint):
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Mask:
+        if isinstance(data, PIL.Image.Image):
+            from torchvision.prototype.transforms import functional as F
+
+            data = F.pil_to_tensor(data)
+
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        return cls._wrap(tensor)
......
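
A follow-up sketch (an inference from the __new__ signature above, not something the diff asserts): because the converted PIL data still passes through _to_tensor, the dtype and device keywords should apply to PIL inputs as well:

import PIL.Image
import torch

from torchvision.prototype import datapoints

# Assumed behavior: pil_to_tensor yields a uint8 tensor, which _to_tensor then
# casts to the requested dtype.
mask = datapoints.Mask(PIL.Image.new("L", (32, 32), color=2), dtype=torch.int64)
assert mask.dtype == torch.int64
assert mask.shape == (1, 32, 32)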