Commit d744da93 (unverified), authored by Philip Meier, committed by GitHub
Browse files

allow subclasses in dataset wrappers (#7236)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent b570f2c1
...@@ -596,7 +596,7 @@ class DatasetTestCase(unittest.TestCase): ...@@ -596,7 +596,7 @@ class DatasetTestCase(unittest.TestCase):
wrapped_sample = wrapped_dataset[0] wrapped_sample = wrapped_dataset[0]
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample) assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
except TypeError as error: except TypeError as error:
if str(error).startswith(f"No wrapper exist for dataset class {type(dataset).__name__}"): if str(error).startswith(f"No wrapper exists for dataset class {type(dataset).__name__}"):
return return
raise error raise error
except RuntimeError as error: except RuntimeError as error:
......
import re
import pytest import pytest
import torch import torch
from PIL import Image from PIL import Image
from torchvision import datasets
from torchvision.prototype import datapoints from torchvision.prototype import datapoints
...@@ -159,3 +163,43 @@ def test_bbox_instance(data, format): ...@@ -159,3 +163,43 @@ def test_bbox_instance(data, format):
if isinstance(format, str): if isinstance(format, str):
format = datapoints.BoundingBoxFormat.from_str(format.upper()) format = datapoints.BoundingBoxFormat.from_str(format.upper())
assert bboxes.format == format assert bboxes.format == format
class TestDatasetWrapper:
    """Tests for ``datapoints.wrap_dataset_for_transforms_v2``."""

    def test_unknown_type(self):
        """An object that is not a ``VisionDataset`` is rejected with a ``TypeError``."""
        unknown_object = object()

        # The message contains backticks and dots, so escape it before using it as a regex.
        with pytest.raises(
            TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`")
        ):
            datapoints.wrap_dataset_for_transforms_v2(unknown_object)

    def test_unknown_dataset(self):
        """A user-defined ``VisionDataset`` subclass without a wrapper raises a ``TypeError``."""

        class MyVisionDataset(datasets.VisionDataset):
            pass

        dataset = MyVisionDataset("root")

        # Pin the grammatically corrected message ("exists"); the shorter pattern
        # "No wrapper exist" would match both the old and the new wording.
        with pytest.raises(TypeError, match="No wrapper exists"):
            datapoints.wrap_dataset_for_transforms_v2(dataset)

    def test_missing_wrapper(self):
        """A dataset exported by ``torchvision.datasets`` gets the "open an issue" hint."""
        dataset = datasets.FakeData()

        with pytest.raises(TypeError, match="please open an issue"):
            datapoints.wrap_dataset_for_transforms_v2(dataset)

    def test_subclass(self, mocker):
        """A subclass of a registered dataset class uses its parent's wrapper factory."""
        sentinel = object()

        # Temporarily register a factory for FakeData whose wrapper always returns the
        # sentinel, so the assertion below can tell which wrapper was picked.
        mocker.patch.dict(
            datapoints._dataset_wrapper.WRAPPER_FACTORIES,
            clear=False,
            values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
        )

        class MyFakeData(datasets.FakeData):
            pass

        dataset = MyFakeData()
        wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset)

        assert wrapped_dataset[0] is sentinel
...@@ -39,16 +39,26 @@ WRAPPER_FACTORIES = WrapperFactories() ...@@ -39,16 +39,26 @@ WRAPPER_FACTORIES = WrapperFactories()
class VisionDatasetDatapointWrapper(Dataset): class VisionDatasetDatapointWrapper(Dataset):
def __init__(self, dataset): def __init__(self, dataset):
dataset_cls = type(dataset) dataset_cls = type(dataset)
wrapper_factory = WRAPPER_FACTORIES.get(dataset_cls)
if wrapper_factory is None: if not isinstance(dataset, datasets.VisionDataset):
# TODO: If we have documentation on how to do that, put a link in the error message. raise TypeError(
msg = f"No wrapper exist for dataset class {dataset_cls.__name__}. Please wrap the output yourself." f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
if dataset_cls in datasets.__dict__.values(): f"but got a '{dataset_cls.__name__}' instead."
msg = ( )
f"{msg} If an automated wrapper for this dataset would be useful for you, "
f"please open an issue at https://github.com/pytorch/vision/issues." for cls in dataset_cls.mro():
) if cls in WRAPPER_FACTORIES:
raise TypeError(msg) wrapper_factory = WRAPPER_FACTORIES[cls]
break
elif cls is datasets.VisionDataset:
# TODO: If we have documentation on how to do that, put a link in the error message.
msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
if dataset_cls in datasets.__dict__.values():
msg = (
f"{msg} If an automated wrapper for this dataset would be useful for you, "
f"please open an issue at https://github.com/pytorch/vision/issues."
)
raise TypeError(msg)
self._dataset = dataset self._dataset = dataset
self._wrapper = wrapper_factory(dataset) self._wrapper = wrapper_factory(dataset)
...@@ -98,6 +108,13 @@ def identity(item): ...@@ -98,6 +108,13 @@ def identity(item):
return item return item
def identity_wrapper_factory(dataset):
    """Build a sample wrapper that returns every sample unchanged.

    Args:
        dataset: The dataset being wrapped. It is accepted so the signature
            matches the other wrapper factories, but it is not used.

    Returns:
        A callable ``wrapper(idx, sample)`` that ignores ``idx`` and returns
        ``sample`` as-is.
    """

    def wrapper(idx, sample):
        return sample

    return wrapper
def pil_image_to_mask(pil_image): def pil_image_to_mask(pil_image):
return datapoints.Mask(pil_image) return datapoints.Mask(pil_image)
...@@ -125,10 +142,7 @@ def wrap_target_by_type(target, *, target_types, type_wrappers): ...@@ -125,10 +142,7 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
def classification_wrapper_factory(dataset): def classification_wrapper_factory(dataset):
def wrapper(idx, sample): return identity_wrapper_factory(dataset)
return sample
return wrapper
for dataset_cls in [ for dataset_cls in [
...@@ -237,6 +251,9 @@ def coco_dectection_wrapper_factory(dataset): ...@@ -237,6 +251,9 @@ def coco_dectection_wrapper_factory(dataset):
return wrapper return wrapper
WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory)
VOC_DETECTION_CATEGORIES = [ VOC_DETECTION_CATEGORIES = [
"__background__", "__background__",
"aeroplane", "aeroplane",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment