Unverified commit 054432d2, authored by Philip Meier, committed by GitHub

enforce pickleability for v2 transforms and wrapped datasets (#7860)

parent 92882b69
@@ -5,6 +5,7 @@ import inspect
 import itertools
 import os
 import pathlib
+import platform
 import random
 import shutil
 import string
@@ -548,7 +549,7 @@ class DatasetTestCase(unittest.TestCase):
     @test_all_configs
     def test_num_examples(self, config):
         with self.create_dataset(config) as (dataset, info):
-            assert len(dataset) == info["num_examples"]
+            assert len(list(dataset)) == len(dataset) == info["num_examples"]

     @test_all_configs
     def test_transforms(self, config):
@@ -692,6 +693,31 @@ class VideoDatasetTestCase(DatasetTestCase):
         super().test_transforms_v2_wrapper.__wrapped__(self, config)


+def _no_collate(batch):
+    return batch
+
+
+def check_transforms_v2_wrapper_spawn(dataset):
+    # On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new
+    # subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is
+    # what we are enforcing here.
+    if platform.system() != "Darwin":
+        pytest.skip("Multiprocessing spawning is only checked on macOS.")
+
+    from torch.utils.data import DataLoader
+    from torchvision import datapoints
+    from torchvision.datasets import wrap_dataset_for_transforms_v2
+
+    wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
+
+    dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)
+
+    for wrapped_sample in dataloader:
+        assert tree_any(
+            lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
+        )
+
+
 def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
     r"""Create a random uint8 tensor.
...
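For context: with `multiprocessing_context="spawn"`, each DataLoader worker receives the dataset by pickling it in the parent process and unpickling it in the child, so everything the dataset holds must survive a pickle round-trip. A minimal, self-contained sketch of that requirement (the `SquareDataset` below is a hypothetical stand-in for a wrapped torchvision dataset):

```python
# Sketch of the spawn/pickle requirement; SquareDataset is hypothetical.
import pickle

from torch.utils.data import DataLoader, Dataset


class SquareDataset(Dataset):
    def __init__(self, size):
        self.size = size  # plain attributes keep the dataset picklable

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return idx * idx


if __name__ == "__main__":
    dataset = SquareDataset(4)
    # Spawned workers get the dataset via exactly this round-trip.
    pickle.loads(pickle.dumps(dataset))

    loader = DataLoader(dataset, num_workers=2, multiprocessing_context="spawn")
    assert sorted(x.item() for batch in loader for x in batch) == [0, 1, 4, 9]
```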
@@ -183,6 +183,10 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
                 ), "Type of the combined target does not match the type of the corresponding individual target: "
                 f"{actual} is not {expected}",

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset(target_type="category") as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Caltech256
@@ -190,7 +194,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
     def inject_fake_data(self, tmpdir, config):
         tmpdir = pathlib.Path(tmpdir) / "caltech256" / "256_ObjectCategories"

-        categories = ((1, "ak47"), (127, "laptop-101"), (257, "clutter"))
+        categories = ((1, "ak47"), (2, "american-flag"), (3, "backpack"))
         num_images_per_category = 2

         for idx, category in categories:
@@ -258,6 +262,10 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
         return split_to_num_examples[config["split"]]

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Cityscapes
@@ -382,6 +390,11 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
             assert isinstance(polygon_img, PIL.Image.Image)
             (polygon_target, info["expected_polygon_target"])

+    def test_transforms_v2_wrapper_spawn(self):
+        for target_type in ["instance", "semantic", ["instance", "semantic"]]:
+            with self.create_dataset(target_type=target_type) as (dataset, _):
+                datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.ImageNet
@@ -413,6 +426,10 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
         torch.save((wnid_to_classes, None), tmpdir / "meta.bin")
         return num_examples

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CIFAR10
@@ -607,6 +624,11 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
         assert merged_imgs_names == all_imgs_names

+    def test_transforms_v2_wrapper_spawn(self):
+        for target_type in ["identity", "bbox", ["identity", "bbox"]]:
+            with self.create_dataset(target_type=target_type) as (dataset, _):
+                datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.VOCSegmentation
@@ -694,6 +716,10 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
         return data

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class VOCDetectionTestCase(VOCSegmentationTestCase):
     DATASET_CLASS = datasets.VOCDetection
@@ -714,6 +740,10 @@ class VOCDetectionTestCase(VOCSegmentationTestCase):
             assert object == info["annotation"]

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CocoDetection
@@ -784,6 +814,10 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
             json.dump(content, fh)
         return file

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class CocoCaptionsTestCase(CocoDetectionTestCase):
     DATASET_CLASS = datasets.CocoCaptions
@@ -800,6 +834,11 @@ class CocoCaptionsTestCase(CocoDetectionTestCase):
             _, captions = dataset[0]
             assert tuple(captions) == tuple(info["captions"])

+    def test_transforms_v2_wrapper_spawn(self):
+        # We need to define this method, because otherwise the test from the super class will
+        # be run
+        pytest.skip("CocoCaptions is currently not supported by the v2 wrapper.")
+

 class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.UCF101
@@ -966,6 +1005,10 @@ class KineticsTestCase(datasets_utils.VideoDatasetTestCase):
         )
         return num_videos_per_class * len(classes)

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset(output_format="TCHW") as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.HMDB51
@@ -1193,6 +1236,10 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
     def _file_stem(self, idx):
         return f"2008_{idx:06d}"

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset(mode="segmentation") as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.FakeData
@@ -1642,6 +1689,10 @@ class KittiTestCase(datasets_utils.ImageDatasetTestCase):
         return split_to_num_examples[config["train"]]

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.SVHN
@@ -2516,6 +2567,10 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
                 breed_id = "-1"
             return (image_id, class_id, species, breed_id)

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.StanfordCars
...
 import itertools
 import pathlib
+import pickle
 import random
 import warnings
@@ -169,8 +170,11 @@ class TestSmoke:
             next(make_vanilla_tensor_images()),
         ],
     )
+    @pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
     @pytest.mark.parametrize("device", cpu_and_cuda())
-    def test_common(self, transform, adapter, container_type, image_or_video, device):
+    def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device):
+        transform = de_serialize(transform)
+
         canvas_size = F.get_size(image_or_video)
         input = dict(
             image_or_video=image_or_video,
...
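The new `de_serialize` parametrization runs each smoke-test case twice: once on the transform as constructed, and once on a copy that went through `pickle.loads(pickle.dumps(...))`, so any transform that fails to pickle or loses state in the round-trip is caught. A standalone sketch of the same idiom, using `v2.Resize` as an arbitrary picklable transform:

```python
# Standalone sketch of the de_serialize round-trip idiom.
import pickle

import torch
from torchvision.transforms import v2

transform = v2.Resize(size=(32, 32), antialias=True)
img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
expected = transform(img)

for de_serialize in (lambda t: t, lambda t: pickle.loads(pickle.dumps(t))):
    # The round-tripped transform must behave exactly like the original.
    assert torch.equal(de_serialize(transform)(img), expected)
```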
@@ -2,6 +2,7 @@ import contextlib
 import decimal
 import inspect
 import math
+import pickle
 import re
 from pathlib import Path
 from unittest import mock
@@ -247,6 +248,8 @@ def _check_transform_v1_compatibility(transform, input):
 def check_transform(transform_cls, input, *args, **kwargs):
     transform = transform_cls(*args, **kwargs)

+    pickle.loads(pickle.dumps(transform))
+
     output = transform(input)
     assert isinstance(output, type(input))
...
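`check_transform` now performs a pickle round-trip before exercising the transform, which flags unpicklable state early. A common culprit is a lambda stored on the instance; both classes below are hypothetical illustrations, not torchvision code:

```python
# Why the pickle check matters: instance attributes holding lambdas
# (or other unpicklable objects) break spawn-based data loading.
import pickle


class GoodTransform:
    def __init__(self, scale=2):
        self.scale = scale  # plain attributes pickle fine

    def __call__(self, x):
        return x * self.scale


class BadTransform:
    def __init__(self):
        self.fn = lambda x: x * 2  # lambdas cannot be pickled

    def __call__(self, x):
        return self.fn(x)


pickle.loads(pickle.dumps(GoodTransform()))  # round-trips cleanly

try:
    pickle.loads(pickle.dumps(BadTransform()))
except (pickle.PicklingError, AttributeError, TypeError):
    print("BadTransform rejected, as the new check would report")
```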
@@ -162,6 +162,7 @@ class VisionDatasetDatapointWrapper:
             raise TypeError(msg)

         self._dataset = dataset
+        self._target_keys = target_keys
         self._wrapper = wrapper_factory(dataset, target_keys)

         # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
@@ -197,6 +198,9 @@ class VisionDatasetDatapointWrapper:
     def __len__(self):
         return len(self._dataset)

+    def __reduce__(self):
+        return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys)
+

 def raise_not_supported(description):
     raise RuntimeError(
...
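The `__reduce__` above makes the wrapper picklable by telling pickle to rebuild it via `wrap_dataset_for_transforms_v2(self._dataset, self._target_keys)` on unpickling, instead of serializing the instance dict (whose `_wrapper` closure is not picklable). A generic sketch of the same protocol, with purely illustrative names:

```python
# Generic illustration of the __reduce__ pattern: pickle stores a
# factory callable plus its arguments instead of the object's state.
import pickle


def make_wrapper(value):
    return Wrapper(value)


class Wrapper:
    def __init__(self, value):
        self.value = value
        # A closure like this would normally make the instance unpicklable.
        self._getter = lambda: self.value

    def __reduce__(self):
        # On unpickling, pickle calls make_wrapper(self.value), which
        # rebuilds the closure from scratch.
        return make_wrapper, (self.value,)


restored = pickle.loads(pickle.dumps(Wrapper(42)))
assert restored._getter() == 42
```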
@@ -137,13 +137,13 @@ class WIDERFace(VisionDataset):
                     {
                         "img_path": img_path,
                         "annotations": {
-                            "bbox": labels_tensor[:, 0:4],  # x, y, width, height
-                            "blur": labels_tensor[:, 4],
-                            "expression": labels_tensor[:, 5],
-                            "illumination": labels_tensor[:, 6],
-                            "occlusion": labels_tensor[:, 7],
-                            "pose": labels_tensor[:, 8],
-                            "invalid": labels_tensor[:, 9],
+                            "bbox": labels_tensor[:, 0:4].clone(),  # x, y, width, height
+                            "blur": labels_tensor[:, 4].clone(),
+                            "expression": labels_tensor[:, 5].clone(),
+                            "illumination": labels_tensor[:, 6].clone(),
+                            "occlusion": labels_tensor[:, 7].clone(),
+                            "pose": labels_tensor[:, 8].clone(),
+                            "invalid": labels_tensor[:, 9].clone(),
                         },
                     }
                 )
...
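A likely motivation for the `.clone()` calls: each `labels_tensor[:, i]` is a view, and serializing a view drags the view's entire base storage along, while a clone carries only its own elements. A small sketch of the effect, assuming torch's usual behavior of saving the full underlying storage of a view:

```python
# A view serializes its whole base storage; a clone owns only its data.
import io

import torch


def serialized_nbytes(t):
    buffer = io.BytesIO()
    torch.save(t, buffer)
    return buffer.getbuffer().nbytes


base = torch.zeros(10_000, 10)
view = base[:, 4]           # still backed by all 100,000 elements
owned = base[:, 4].clone()  # owns exactly 10,000 elements

assert serialized_nbytes(view) > serialized_nbytes(owned)
```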