Unverified commit d744da93, authored by Philip Meier, committed by GitHub
Browse files

allow subclasses in dataset wrappers (#7236)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent b570f2c1
......@@ -596,7 +596,7 @@ class DatasetTestCase(unittest.TestCase):
wrapped_sample = wrapped_dataset[0]
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
except TypeError as error:
if str(error).startswith(f"No wrapper exist for dataset class {type(dataset).__name__}"):
if str(error).startswith(f"No wrapper exists for dataset class {type(dataset).__name__}"):
return
raise error
except RuntimeError as error:
......
import re
import pytest
import torch
from PIL import Image
from torchvision import datasets
from torchvision.prototype import datapoints
......@@ -159,3 +163,43 @@ def test_bbox_instance(data, format):
if isinstance(format, str):
format = datapoints.BoundingBoxFormat.from_str(format.upper())
assert bboxes.format == format
class TestDatasetWrapper:
    """Tests for ``datapoints.wrap_dataset_for_transforms_v2``."""

    def test_unknown_type(self):
        """An object that is not a ``VisionDataset`` is rejected with a ``TypeError``."""
        unknown_object = object()

        with pytest.raises(
            TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
        ):
            datapoints.wrap_dataset_for_transforms_v2(unknown_object)

    def test_unknown_dataset(self):
        """A custom ``VisionDataset`` subclass with no registered wrapper raises a ``TypeError``."""

        class MyVisionDataset(datasets.VisionDataset):
            pass

        dataset = MyVisionDataset("root")

        # The error message reads "No wrapper exists for dataset class ..."; match the
        # full word instead of relying on "exist" being a prefix of it.
        with pytest.raises(TypeError, match="No wrapper exists"):
            datapoints.wrap_dataset_for_transforms_v2(dataset)

    def test_missing_wrapper(self):
        """A dataset class exposed by ``torchvision.datasets`` gets the "open an issue" hint."""
        dataset = datasets.FakeData()

        with pytest.raises(TypeError, match="please open an issue"):
            datapoints.wrap_dataset_for_transforms_v2(dataset)

    def test_subclass(self, mocker):
        """A subclass of a dataset with a registered wrapper factory reuses that factory."""
        sentinel = object()
        # Temporarily register a factory for FakeData whose wrapper always returns the
        # sentinel; clear=False keeps every other registered factory in place.
        mocker.patch.dict(
            datapoints._dataset_wrapper.WRAPPER_FACTORIES,
            clear=False,
            values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
        )

        class MyFakeData(datasets.FakeData):
            pass

        dataset = MyFakeData()
        wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset)

        # Getting the sentinel back shows the factory registered for the parent class was used.
        assert wrapped_dataset[0] is sentinel
......@@ -39,16 +39,26 @@ WRAPPER_FACTORIES = WrapperFactories()
class VisionDatasetDatapointWrapper(Dataset):
def __init__(self, dataset):
dataset_cls = type(dataset)
wrapper_factory = WRAPPER_FACTORIES.get(dataset_cls)
if wrapper_factory is None:
# TODO: If we have documentation on how to do that, put a link in the error message.
msg = f"No wrapper exist for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
if dataset_cls in datasets.__dict__.values():
msg = (
f"{msg} If an automated wrapper for this dataset would be useful for you, "
f"please open an issue at https://github.com/pytorch/vision/issues."
)
raise TypeError(msg)
if not isinstance(dataset, datasets.VisionDataset):
raise TypeError(
f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
f"but got a '{dataset_cls.__name__}' instead."
)
for cls in dataset_cls.mro():
if cls in WRAPPER_FACTORIES:
wrapper_factory = WRAPPER_FACTORIES[cls]
break
elif cls is datasets.VisionDataset:
# TODO: If we have documentation on how to do that, put a link in the error message.
msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
if dataset_cls in datasets.__dict__.values():
msg = (
f"{msg} If an automated wrapper for this dataset would be useful for you, "
f"please open an issue at https://github.com/pytorch/vision/issues."
)
raise TypeError(msg)
self._dataset = dataset
self._wrapper = wrapper_factory(dataset)
......@@ -98,6 +108,13 @@ def identity(item):
return item
def identity_wrapper_factory(dataset):
    """Build a sample wrapper that leaves every sample untouched.

    ``dataset`` is accepted so the signature matches the other wrapper factories;
    it is not used.
    """

    def return_unchanged(idx, sample):
        # Neither the index nor the dataset influences the result.
        return sample

    return return_unchanged
def pil_image_to_mask(pil_image):
    """Wrap ``pil_image`` in a ``datapoints.Mask`` and return the result."""
    mask = datapoints.Mask(pil_image)
    return mask
......@@ -125,10 +142,7 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
def classification_wrapper_factory(dataset):
    """Create the sample wrapper for ``dataset``; samples are returned unchanged.

    The pass-through behaviour lives in ``identity_wrapper_factory``, so this
    factory delegates to it instead of defining its own copy of the wrapper.
    """
    return identity_wrapper_factory(dataset)
for dataset_cls in [
......@@ -237,6 +251,9 @@ def coco_dectection_wrapper_factory(dataset):
return wrapper
WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory)
VOC_DETECTION_CATEGORIES = [
"__background__",
"aeroplane",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment