Unverified commit 054432d2 authored by Philip Meier, committed by GitHub

enforce pickleability for v2 transforms and wrapped datasets (#7860)

parent 92882b69
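The contract enforced throughout this diff: a v2 transform or a wrapped dataset must survive a pickle round trip, because DataLoader workers started via spawn receive their copy of the pipeline as pickled bytes. A minimal sketch of that contract (assuming the torchvision.transforms.v2 namespace):

import pickle

from torchvision.transforms import v2

transform = v2.RandomHorizontalFlip(p=0.5)

# Spawned workers rebuild the pipeline from bytes, so this round trip must
# succeed and yield a working transform.
reconstructed = pickle.loads(pickle.dumps(transform))
assert isinstance(reconstructed, v2.RandomHorizontalFlip)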
@@ -5,6 +5,7 @@ import inspect
import itertools
import os
import pathlib
import platform
import random
import shutil
import string
@@ -548,7 +549,7 @@ class DatasetTestCase(unittest.TestCase):
@test_all_configs
def test_num_examples(self, config):
with self.create_dataset(config) as (dataset, info):
- assert len(dataset) == info["num_examples"]
+ assert len(list(dataset)) == len(dataset) == info["num_examples"]
@test_all_configs
def test_transforms(self, config):
@@ -692,6 +693,31 @@ class VideoDatasetTestCase(DatasetTestCase):
super().test_transforms_v2_wrapper.__wrapped__(self, config)
def _no_collate(batch):
return batch
def check_transforms_v2_wrapper_spawn(dataset):
# On Linux, the DataLoader forks the main process by default. On macOS and Windows, forking is unavailable or
# unsafe, so new subprocesses are spawned instead. Spawning requires the whole pipeline, including the dataset,
# to be pickleable, which is what we are enforcing here.
if platform.system() != "Darwin":
pytest.skip("Multiprocessing spawning is only checked on macOS.")
from torch.utils.data import DataLoader
from torchvision import datapoints
from torchvision.datasets import wrap_dataset_for_transforms_v2
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)
for wrapped_sample in dataloader:
assert tree_any(
lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
)
def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
r"""Create a random uint8 tensor.
......
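A toy illustration (not part of the commit) of the fork/spawn distinction that check_transforms_v2_wrapper_spawn above relies on: fork copies the parent's memory wholesale, while spawn starts a fresh interpreter and must pickle everything handed to the worker.

import multiprocessing as mp

def worker(fn):
    fn("hello")

if __name__ == "__main__":
    unpicklable = lambda x: print(x)  # lambdas cannot be pickled by reference

    # Under fork, the child inherits the parent's memory; nothing is pickled.
    if "fork" in mp.get_all_start_methods():
        p = mp.get_context("fork").Process(target=worker, args=(unpicklable,))
        p.start()
        p.join()

    # Under spawn, the process object and its arguments are pickled, so the
    # lambda makes start() raise.
    try:
        p = mp.get_context("spawn").Process(target=worker, args=(unpicklable,))
        p.start()
        p.join()
    except Exception as exc:
        print(f"spawn failed as expected: {exc!r}")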
@@ -183,6 +183,10 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
), "Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}",
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(target_type="category") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Caltech256
@@ -190,7 +194,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir) / "caltech256" / "256_ObjectCategories"
categories = ((1, "ak47"), (127, "laptop-101"), (257, "clutter"))
categories = ((1, "ak47"), (2, "american-flag"), (3, "backpack"))
num_images_per_category = 2
for idx, category in categories:
@@ -258,6 +262,10 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
return split_to_num_examples[config["split"]]
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Cityscapes
@@ -382,6 +390,11 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
assert isinstance(polygon_img, PIL.Image.Image)
assert_equal(polygon_target, info["expected_polygon_target"])
def test_transforms_v2_wrapper_spawn(self):
for target_type in ["instance", "semantic", ["instance", "semantic"]]:
with self.create_dataset(target_type=target_type) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.ImageNet
@@ -413,6 +426,10 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
torch.save((wnid_to_classes, None), tmpdir / "meta.bin")
return num_examples
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CIFAR10
@@ -607,6 +624,11 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
assert merged_imgs_names == all_imgs_names
def test_transforms_v2_wrapper_spawn(self):
for target_type in ["identity", "bbox", ["identity", "bbox"]]:
with self.create_dataset(target_type=target_type) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.VOCSegmentation
@@ -694,6 +716,10 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
return data
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class VOCDetectionTestCase(VOCSegmentationTestCase):
DATASET_CLASS = datasets.VOCDetection
@@ -714,6 +740,10 @@ class VOCDetectionTestCase(VOCSegmentationTestCase):
assert object == info["annotation"]
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CocoDetection
@@ -784,6 +814,10 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
json.dump(content, fh)
return file
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class CocoCaptionsTestCase(CocoDetectionTestCase):
DATASET_CLASS = datasets.CocoCaptions
@@ -800,6 +834,11 @@ class CocoCaptionsTestCase(CocoDetectionTestCase):
_, captions = dataset[0]
assert tuple(captions) == tuple(info["captions"])
def test_transforms_v2_wrapper_spawn(self):
# We need to define this method, because otherwise the test from the super class will
# be run
pytest.skip("CocoCaptions is currently not supported by the v2 wrapper.")
class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.UCF101
@@ -966,6 +1005,10 @@ class KineticsTestCase(datasets_utils.VideoDatasetTestCase):
)
return num_videos_per_class * len(classes)
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(output_format="TCHW") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
DATASET_CLASS = datasets.HMDB51
@@ -1193,6 +1236,10 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
def _file_stem(self, idx):
return f"2008_{idx:06d}"
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset(mode="segmentation") as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FakeData
@@ -1642,6 +1689,10 @@ class KittiTestCase(datasets_utils.ImageDatasetTestCase):
return split_to_num_examples[config["train"]]
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SVHN
@@ -2516,6 +2567,10 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
breed_id = "-1"
return (image_id, class_id, species, breed_id)
def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.StanfordCars
......
import itertools
import pathlib
import pickle
import random
import warnings
@@ -169,8 +170,11 @@ class TestSmoke:
next(make_vanilla_tensor_images()),
],
)
@pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_common(self, transform, adapter, container_type, image_or_video, device):
def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device):
transform = de_serialize(transform)
canvas_size = F.get_size(image_or_video)
input = dict(
image_or_video=image_or_video,
......
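The de_serialize parametrization above runs every smoke test twice: once on the transform as constructed (the identity lambda) and once on a pickle round-tripped copy, so any transform that fails to serialize or loses state in transit now fails across the whole v2 test matrix.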
@@ -2,6 +2,7 @@ import contextlib
import decimal
import inspect
import math
import pickle
import re
from pathlib import Path
from unittest import mock
@@ -247,6 +248,8 @@ def _check_transform_v1_compatibility(transform, input):
def check_transform(transform_cls, input, *args, **kwargs):
transform = transform_cls(*args, **kwargs)
pickle.loads(pickle.dumps(transform))
output = transform(input)
assert isinstance(output, type(input))
......
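Similarly, round-tripping the freshly constructed transform at the top of check_transform turns every transform-level check into a pickleability check as well, covering transforms that the dataset-level spawn tests never reach.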
@@ -162,6 +162,7 @@ class VisionDatasetDatapointWrapper:
raise TypeError(msg)
self._dataset = dataset
self._target_keys = target_keys
self._wrapper = wrapper_factory(dataset, target_keys)
# We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
@@ -197,6 +198,9 @@ class VisionDatasetDatapointWrapper:
def __len__(self):
return len(self._dataset)
def __reduce__(self):
return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys)
def raise_not_supported(description):
raise RuntimeError(
......
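The __reduce__ hook above is what makes the wrapper pickleable: self._wrapper holds the per-dataset wrapper produced by wrapper_factory, typically a closure that pickle cannot serialize by reference, so the wrapper instead tells pickle to record the factory call and replay it on unpickling. A generic sketch of the protocol, with hypothetical names:

import pickle

def make_greeter(name):
    return Greeter(name)

class Greeter:
    def __init__(self, name):
        self.name = name
        self.say = lambda: f"hello {name}"  # unpicklable attribute

    def __reduce__(self):
        # pickle stores (callable, args) and calls make_greeter(self.name) on
        # load, so the lambda attribute never needs to be serialized
        return make_greeter, (self.name,)

g = pickle.loads(pickle.dumps(Greeter("world")))
assert g.say() == "hello world"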
@@ -137,13 +137,13 @@ class WIDERFace(VisionDataset):
{
"img_path": img_path,
"annotations": {
"bbox": labels_tensor[:, 0:4], # x, y, width, height
"blur": labels_tensor[:, 4],
"expression": labels_tensor[:, 5],
"illumination": labels_tensor[:, 6],
"occlusion": labels_tensor[:, 7],
"pose": labels_tensor[:, 8],
"invalid": labels_tensor[:, 9],
"bbox": labels_tensor[:, 0:4].clone(), # x, y, width, height
"blur": labels_tensor[:, 4].clone(),
"expression": labels_tensor[:, 5].clone(),
"illumination": labels_tensor[:, 6].clone(),
"occlusion": labels_tensor[:, 7].clone(),
"pose": labels_tensor[:, 8].clone(),
"invalid": labels_tensor[:, 9].clone(),
},
}
)
......
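The added .clone() calls serve the same goal: a sliced tensor is a view, and pickling a view serializes the entire storage it shares with its base, so without the clones each annotation entry would drag the full label matrix into every pickled sample. A rough illustration (sizes are indicative, not exact):

import pickle

import torch

base = torch.zeros(10_000, 10)
view = base[:, 0:4]           # shares storage with all 100,000 elements
clone = base[:, 0:4].clone()  # owns just its 40,000 elements

print(len(pickle.dumps(view)))   # on the order of the full base storage
print(len(pickle.dumps(clone)))  # roughly 40% of that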