Unverified commit bdf16222, authored by Philip Meier and committed by GitHub
Browse files

add support for instance checks on dataset wrappers (#7239)

parent 9b4ec8df
...@@ -178,9 +178,7 @@ def get_coco_api_from_dataset(dataset): ...@@ -178,9 +178,7 @@ def get_coco_api_from_dataset(dataset):
break break
if isinstance(dataset, torch.utils.data.Subset): if isinstance(dataset, torch.utils.data.Subset):
dataset = dataset.dataset dataset = dataset.dataset
if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance( if isinstance(dataset, torchvision.datasets.CocoDetection):
getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
):
return dataset.coco return dataset.coco
return convert_to_coco_api(dataset) return convert_to_coco_api(dataset)
......
...@@ -164,9 +164,7 @@ def compute_aspect_ratios(dataset, indices=None): ...@@ -164,9 +164,7 @@ def compute_aspect_ratios(dataset, indices=None):
if hasattr(dataset, "get_height_and_width"): if hasattr(dataset, "get_height_and_width"):
return _compute_aspect_ratios_custom_dataset(dataset, indices) return _compute_aspect_ratios_custom_dataset(dataset, indices)
if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance( if isinstance(dataset, torchvision.datasets.CocoDetection):
getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
):
return _compute_aspect_ratios_coco_dataset(dataset, indices) return _compute_aspect_ratios_coco_dataset(dataset, indices)
if isinstance(dataset, torchvision.datasets.VOCDetection): if isinstance(dataset, torchvision.datasets.VOCDetection):
......
...@@ -571,7 +571,7 @@ class DatasetTestCase(unittest.TestCase): ...@@ -571,7 +571,7 @@ class DatasetTestCase(unittest.TestCase):
from torchvision.datasets import wrap_dataset_for_transforms_v2 from torchvision.datasets import wrap_dataset_for_transforms_v2
try: try:
with self.create_dataset(config) as (dataset, _): with self.create_dataset(config) as (dataset, info):
for target_keys in [None, "all"]: for target_keys in [None, "all"]:
if target_keys is not None and self.DATASET_CLASS not in { if target_keys is not None and self.DATASET_CLASS not in {
torchvision.datasets.CocoDetection, torchvision.datasets.CocoDetection,
...@@ -584,8 +584,10 @@ class DatasetTestCase(unittest.TestCase): ...@@ -584,8 +584,10 @@ class DatasetTestCase(unittest.TestCase):
continue continue
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys) wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
wrapped_sample = wrapped_dataset[0] assert isinstance(wrapped_dataset, self.DATASET_CLASS)
assert len(wrapped_dataset) == info["num_examples"]
wrapped_sample = wrapped_dataset[0]
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample) assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
except TypeError as error: except TypeError as error:
msg = f"No wrapper exists for dataset class {type(dataset).__name__}" msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
......
...@@ -8,7 +8,6 @@ import contextlib ...@@ -8,7 +8,6 @@ import contextlib
from collections import defaultdict from collections import defaultdict
import torch import torch
from torch.utils.data import Dataset
from torchvision import datapoints, datasets from torchvision import datapoints, datasets
from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2 import functional as F
...@@ -98,7 +97,16 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None): ...@@ -98,7 +97,16 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
f"but got {target_keys}" f"but got {target_keys}"
) )
return VisionDatasetDatapointWrapper(dataset, target_keys) # Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name
# "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the
# original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks,
# while we can still inject everything that we need.
wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetDatapointWrapper, type(dataset)), {})
# Since VisionDatasetDatapointWrapper comes before ImageNet in the MRO, calling the class hits
# VisionDatasetDatapointWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of
# ImageNet is never hit. That is by design, since we don't want to create the dataset instance again, but rather
# have the existing instance as attribute on the new object.
return wrapped_dataset_cls(dataset, target_keys)
class WrapperFactories(dict): class WrapperFactories(dict):
...@@ -117,7 +125,7 @@ class WrapperFactories(dict): ...@@ -117,7 +125,7 @@ class WrapperFactories(dict):
WRAPPER_FACTORIES = WrapperFactories() WRAPPER_FACTORIES = WrapperFactories()
class VisionDatasetDatapointWrapper(Dataset): class VisionDatasetDatapointWrapper:
def __init__(self, dataset, target_keys): def __init__(self, dataset, target_keys):
dataset_cls = type(dataset) dataset_cls = type(dataset)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment