"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "0c445130ba3416fdd9e1ad0fccd312b0f574e57a"
Unverified commit b80bdb75, authored by Nicolas Hug, committed by GitHub

Fix v2 transforms in spawn mp context (#8067)

parent 96d2ce91
@@ -27,7 +27,11 @@ import torchvision.datasets
 import torchvision.io
 from common_utils import disable_console_output, get_tmp_dir
 from torch.utils._pytree import tree_any
+from torch.utils.data import DataLoader
+from torchvision import tv_tensors
+from torchvision.datasets import wrap_dataset_for_transforms_v2
 from torchvision.transforms.functional import get_dimensions
+from torchvision.transforms.v2.functional import get_size


 __all__ = [
@@ -568,9 +572,6 @@ class DatasetTestCase(unittest.TestCase):
     @test_all_configs
     def test_transforms_v2_wrapper(self, config):
-        from torchvision import tv_tensors
-        from torchvision.datasets import wrap_dataset_for_transforms_v2
-
         try:
             with self.create_dataset(config) as (dataset, info):
                 for target_keys in [None, "all"]:
@@ -709,26 +710,29 @@ def _no_collate(batch):
     return batch


-def check_transforms_v2_wrapper_spawn(dataset):
-    # On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new
-    # subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what
-    # we are enforcing here.
-    if platform.system() != "Darwin":
-        pytest.skip("Multiprocessing spawning is only checked on macOS.")
-
-    from torch.utils.data import DataLoader
-    from torchvision import tv_tensors
-    from torchvision.datasets import wrap_dataset_for_transforms_v2
+def check_transforms_v2_wrapper_spawn(dataset, expected_size):
+    # This check ensures that the wrapped datasets can be used with multiprocessing_context="spawn" in the DataLoader.
+    # We also check that transforms are applied correctly as a non-regression test for
+    # https://github.com/pytorch/vision/issues/8066
+    # Implicitly, this also checks that the wrapped datasets are pickleable.
+
+    # To save CI/test time, we only check on Windows, where "spawn" is the default.
+    if platform.system() != "Windows":
+        pytest.skip("Multiprocessing spawning is only checked on Windows.")

     wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)

     dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)

-    for wrapped_sample in dataloader:
-        assert tree_any(
-            lambda item: isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)), wrapped_sample
-        )
+    def resize_was_applied(item):
+        # Checking the size of the output ensures that the Resize transform was correctly applied
+        return isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)) and get_size(item) == list(
+            expected_size
+        )
+
+    for wrapped_sample in dataloader:
+        assert tree_any(resize_was_applied, wrapped_sample)


 def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
     r"""Create a random uint8 tensor.
...
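The reworked helper relies on two building blocks: torch.utils._pytree.tree_any to walk the arbitrarily nested sample produced by the wrapped dataset, and torchvision.transforms.v2.functional.get_size to confirm that the Resize transform actually ran. The following is a minimal, standalone sketch of that size check outside of any dataset or DataLoader; the sample contents here are made up purely for illustration.

import torch
from torch.utils._pytree import tree_any
from torchvision import tv_tensors
from torchvision.transforms import v2
from torchvision.transforms.v2.functional import get_size

expected_size = (123, 321)  # (height, width), same value the tests use

# Mimic what a wrapped dataset yields: an (image, target) pair whose image is a tv_tensor.
image = tv_tensors.Image(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8))
sample = (v2.Resize(size=expected_size)(image), {"label": 0})

def resize_was_applied(item):
    # get_size returns [height, width], so compare against list(expected_size)
    return isinstance(item, (tv_tensors.Image, tv_tensors.Video)) and get_size(item) == list(expected_size)

# tree_any flattens the nested sample and returns True if any leaf satisfies the predicate.
assert tree_any(resize_was_applied, sample)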
@@ -24,6 +24,7 @@ import torch
 import torch.nn.functional as F
 from common_utils import combinations_grid
 from torchvision import datasets
+from torchvision.transforms import v2


 class STL10TestCase(datasets_utils.ImageDatasetTestCase):
@@ -184,8 +185,9 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
             f"{actual} is not {expected}",

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset(target_type="category") as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(target_type="category", transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
@@ -263,8 +265,9 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
         return split_to_num_examples[config["split"]]

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
@@ -391,9 +394,10 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
             (polygon_target, info["expected_polygon_target"])

     def test_transforms_v2_wrapper_spawn(self):
+        expected_size = (123, 321)
         for target_type in ["instance", "semantic", ["instance", "semantic"]]:
-            with self.create_dataset(target_type=target_type) as (dataset, _):
-                datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+            with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
+                datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
@@ -427,8 +431,9 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
         return num_examples

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
@@ -625,9 +630,10 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
         assert merged_imgs_names == all_imgs_names

     def test_transforms_v2_wrapper_spawn(self):
+        expected_size = (123, 321)
         for target_type in ["identity", "bbox", ["identity", "bbox"]]:
-            with self.create_dataset(target_type=target_type) as (dataset, _):
-                datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+            with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _):
+                datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
@@ -717,8 +723,9 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
         return data

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class VOCDetectionTestCase(VOCSegmentationTestCase):
@@ -741,8 +748,9 @@ class VOCDetectionTestCase(VOCSegmentationTestCase):
         assert object == info["annotation"]

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
@@ -815,8 +823,9 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
         return file

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class CocoCaptionsTestCase(CocoDetectionTestCase):
@@ -1005,9 +1014,11 @@ class KineticsTestCase(datasets_utils.VideoDatasetTestCase):
         )
         return num_videos_per_class * len(classes)

+    @pytest.mark.xfail(reason="FIXME")
     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset(output_format="TCHW") as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(output_format="TCHW", transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
@@ -1237,8 +1248,9 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
         return f"2008_{idx:06d}"

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset(mode="segmentation") as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(mode="segmentation", transforms=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
@@ -1690,8 +1702,9 @@ class KittiTestCase(datasets_utils.ImageDatasetTestCase):
         return split_to_num_examples[config["train"]]

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
@@ -2568,8 +2581,9 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
         return (image_id, class_id, species, breed_id)

     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
-            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+        expected_size = (123, 321)
+        with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size)


 class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -6,6 +6,7 @@ import collections.abc
 import contextlib
 from collections import defaultdict
+from copy import copy

 import torch
@@ -198,8 +199,19 @@ class VisionDatasetTVTensorWrapper:
     def __len__(self):
         return len(self._dataset)

+    # TODO: maybe we should use __getstate__ and __setstate__ instead of __reduce__, as recommended in the docs.
     def __reduce__(self):
-        return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys)
+        # __reduce__ gets called when we try to pickle the dataset.
+        # In a DataLoader with spawn context, this gets called `num_workers` times from the main process.
+        # We have to reset the [target_]transform[s] attributes of the dataset
+        # to their original values, because we previously set them to None in __init__().
+
+        dataset = copy(self._dataset)
+        dataset.transform = self.transform
+        dataset.transforms = self.transforms
+        dataset.target_transform = self.target_transform
+        return wrap_dataset_for_transforms_v2, (dataset, self._target_keys)


 def raise_not_supported(description):
...
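For context on the pickling hook used in the fix: __reduce__ must return a (callable, args) pair, and unpickling calls callable(*args) to rebuild the object. The fix works because args now carries a copy of the inner dataset with its transforms restored, so each spawned worker reconstructs a wrapper that still applies them. Below is a minimal, self-contained sketch of that mechanism using made-up toy classes; it is not torchvision's actual code.

import pickle
from copy import copy

class ToyDataset:
    def __init__(self, transform=None):
        self.transform = transform

class ToyWrapper:
    def __init__(self, dataset):
        self._dataset = dataset
        # The wrapper takes over the transform and disables it on the inner dataset,
        # analogous to what the dataset wrapper above does in __init__().
        self.transform = dataset.transform
        dataset.transform = None

    def __reduce__(self):
        # Hand pickle a recipe: "call wrap_toy on a dataset whose transform has been restored".
        dataset = copy(self._dataset)
        dataset.transform = self.transform
        return wrap_toy, (dataset,)

def wrap_toy(dataset):
    return ToyWrapper(dataset)

wrapper = ToyWrapper(ToyDataset(transform="resize"))
rebuilt = pickle.loads(pickle.dumps(wrapper))
assert rebuilt.transform == "resize"  # the transform survived the pickle round-trip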