Unverified Commit bdf16222 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add support for instance checks on dataset wrappers (#7239)

parent 9b4ec8df
......@@ -178,9 +178,7 @@ def get_coco_api_from_dataset(dataset):
break
if isinstance(dataset, torch.utils.data.Subset):
dataset = dataset.dataset
if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
):
if isinstance(dataset, torchvision.datasets.CocoDetection):
return dataset.coco
return convert_to_coco_api(dataset)
......
......@@ -164,9 +164,7 @@ def compute_aspect_ratios(dataset, indices=None):
if hasattr(dataset, "get_height_and_width"):
return _compute_aspect_ratios_custom_dataset(dataset, indices)
if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
):
if isinstance(dataset, torchvision.datasets.CocoDetection):
return _compute_aspect_ratios_coco_dataset(dataset, indices)
if isinstance(dataset, torchvision.datasets.VOCDetection):
......
......@@ -571,7 +571,7 @@ class DatasetTestCase(unittest.TestCase):
from torchvision.datasets import wrap_dataset_for_transforms_v2
try:
with self.create_dataset(config) as (dataset, _):
with self.create_dataset(config) as (dataset, info):
for target_keys in [None, "all"]:
if target_keys is not None and self.DATASET_CLASS not in {
torchvision.datasets.CocoDetection,
......@@ -584,8 +584,10 @@ class DatasetTestCase(unittest.TestCase):
continue
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
wrapped_sample = wrapped_dataset[0]
assert isinstance(wrapped_dataset, self.DATASET_CLASS)
assert len(wrapped_dataset) == info["num_examples"]
wrapped_sample = wrapped_dataset[0]
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
except TypeError as error:
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
......
......@@ -8,7 +8,6 @@ import contextlib
from collections import defaultdict
import torch
from torch.utils.data import Dataset
from torchvision import datapoints, datasets
from torchvision.transforms.v2 import functional as F
......@@ -98,7 +97,16 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
f"but got {target_keys}"
)
return VisionDatasetDatapointWrapper(dataset, target_keys)
# Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name
# "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the
# original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks,
# while we can still inject everything that we need.
wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetDatapointWrapper, type(dataset)), {})
# Since VisionDatasetDatapointWrapper comes before ImageNet in the MRO, calling the class hits
# VisionDatasetDatapointWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of
# ImageNet is never hit. That is by design, since we don't want to create the dataset instance again, but rather
# have the existing instance as attribute on the new object.
return wrapped_dataset_cls(dataset, target_keys)
class WrapperFactories(dict):
......@@ -117,7 +125,7 @@ class WrapperFactories(dict):
WRAPPER_FACTORIES = WrapperFactories()
class VisionDatasetDatapointWrapper(Dataset):
class VisionDatasetDatapointWrapper:
def __init__(self, dataset, target_keys):
dataset_cls = type(dataset)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment