Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b80bdb75
Unverified
Commit
b80bdb75
authored
Oct 27, 2023
by
Nicolas Hug
Committed by
GitHub
Oct 27, 2023
Browse files
Fix v2 transforms in spawn mp context (#8067)
parent
96d2ce91
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
70 additions
and
40 deletions
+70
-40
test/datasets_utils.py
test/datasets_utils.py
+19
-15
test/test_datasets.py
test/test_datasets.py
+38
-24
torchvision/tv_tensors/_dataset_wrapper.py
torchvision/tv_tensors/_dataset_wrapper.py
+13
-1
No files found.
test/datasets_utils.py
View file @
b80bdb75
...
...
@@ -27,7 +27,11 @@ import torchvision.datasets
import
torchvision.io
from
common_utils
import
disable_console_output
,
get_tmp_dir
from
torch.utils._pytree
import
tree_any
from
torch.utils.data
import
DataLoader
from
torchvision
import
tv_tensors
from
torchvision.datasets
import
wrap_dataset_for_transforms_v2
from
torchvision.transforms.functional
import
get_dimensions
from
torchvision.transforms.v2.functional
import
get_size
__all__
=
[
...
...
@@ -568,9 +572,6 @@ class DatasetTestCase(unittest.TestCase):
@
test_all_configs
def
test_transforms_v2_wrapper
(
self
,
config
):
from
torchvision
import
tv_tensors
from
torchvision.datasets
import
wrap_dataset_for_transforms_v2
try
:
with
self
.
create_dataset
(
config
)
as
(
dataset
,
info
):
for
target_keys
in
[
None
,
"all"
]:
...
...
@@ -709,26 +710,29 @@ def _no_collate(batch):
return
batch
def
check_transforms_v2_wrapper_spawn
(
dataset
):
# On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new
# subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what
# we are enforcing here.
if
platform
.
system
()
!=
"Darwin"
:
pytest
.
skip
(
"Multiprocessing spawning is only checked on macOS."
)
def
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
):
# This check ensures that the wrapped datasets can be used with multiprocessing_context="spawn" in the DataLoader.
# We also check that transforms are applied correctly as a non-regression test for
# https://github.com/pytorch/vision/issues/8066
# Implicitly, this also checks that the wrapped datasets are pickleable.
from
torch.utils.data
import
DataLoader
from
torchvision
import
tv_tensors
from
torchvision.datasets
import
wrap_dataset_for_transforms_v2
# To save CI/test time, we only check on Windows where "spawn" is the default
if
platform
.
system
()
!=
"Windows"
:
pytest
.
skip
(
"Multiprocessing spawning is only checked on macOS."
)
wrapped_dataset
=
wrap_dataset_for_transforms_v2
(
dataset
)
dataloader
=
DataLoader
(
wrapped_dataset
,
num_workers
=
2
,
multiprocessing_context
=
"spawn"
,
collate_fn
=
_no_collate
)
for
wrapped_sample
in
dataloader
:
assert
tree_any
(
lambda
item
:
isinstance
(
item
,
(
tv_tensors
.
Image
,
tv_tensors
.
Video
,
PIL
.
Image
.
Image
)),
wrapped_sample
)
def
resize_was_applied
(
item
):
# Checking the size of the output ensures that the Resize transform was correctly applied
return
isinstance
(
item
,
(
tv_tensors
.
Image
,
tv_tensors
.
Video
,
PIL
.
Image
.
Image
))
and
get_size
(
item
)
==
list
(
expected_size
)
for
wrapped_sample
in
dataloader
:
assert
tree_any
(
resize_was_applied
,
wrapped_sample
)
def
create_image_or_video_tensor
(
size
:
Sequence
[
int
])
->
torch
.
Tensor
:
r
"""Create a random uint8 tensor.
...
...
test/test_datasets.py
View file @
b80bdb75
...
...
@@ -24,6 +24,7 @@ import torch
import
torch.nn.functional
as
F
from
common_utils
import
combinations_grid
from
torchvision
import
datasets
from
torchvision.transforms
import
v2
class
STL10TestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
...
@@ -184,8 +185,9 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
f
"
{
actual
}
is not
{
expected
}
"
,
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
(
target_type
=
"category"
)
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
expected_size
=
(
123
,
321
)
with
self
.
create_dataset
(
target_type
=
"category"
,
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
Caltech256TestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
...
@@ -263,8 +265,9 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
return
split_to_num_examples
[
config
[
"split"
]]
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
()
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
expected_size
=
(
123
,
321
)
with
self
.
create_dataset
(
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
CityScapesTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
...
@@ -391,9 +394,10 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
(
polygon_target
,
info
[
"expected_polygon_target"
])
def
test_transforms_v2_wrapper_spawn
(
self
):
expected_size
=
(
123
,
321
)
for
target_type
in
[
"instance"
,
"semantic"
,
[
"instance"
,
"semantic"
]]:
with
self
.
create_dataset
(
target_type
=
target_type
)
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
with
self
.
create_dataset
(
target_type
=
target_type
,
transform
=
v2
.
Resize
(
size
=
expected_size
)
)
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
ImageNetTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
...
@@ -427,8 +431,9 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
return
num_examples
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
()
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
expected_size
=
(
123
,
321
)
with
self
.
create_dataset
(
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
CIFAR10TestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
...
@@ -625,9 +630,10 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
assert
merged_imgs_names
==
all_imgs_names
def
test_transforms_v2_wrapper_spawn
(
self
):
expected_size
=
(
123
,
321
)
for
target_type
in
[
"identity"
,
"bbox"
,
[
"identity"
,
"bbox"
]]:
with
self
.
create_dataset
(
target_type
=
target_type
)
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
with
self
.
create_dataset
(
target_type
=
target_type
,
transform
=
v2
.
Resize
(
size
=
expected_size
)
)
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
VOCSegmentationTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
...
@@ -717,8 +723,9 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
return
data
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
()
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
expected_size
=
(
123
,
321
)
with
self
.
create_dataset
(
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
VOCDetectionTestCase
(
VOCSegmentationTestCase
):
...
...
@@ -741,8 +748,9 @@ class VOCDetectionTestCase(VOCSegmentationTestCase):
assert
object
==
info
[
"annotation"
]
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
()
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
expected_size
=
(
123
,
321
)
with
self
.
create_dataset
(
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
CocoDetectionTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
...
@@ -815,8 +823,9 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
return
file
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
()
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
expected_size
=
(
123
,
321
)
with
self
.
create_dataset
(
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
CocoCaptionsTestCase
(
CocoDetectionTestCase
):
...
...
@@ -1005,9 +1014,11 @@ class KineticsTestCase(datasets_utils.VideoDatasetTestCase):
)
return
num_videos_per_class
*
len
(
classes
)
@
pytest
.
mark
.
xfail
(
reason
=
"FIXME"
)
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
(
output_format
=
"TCHW"
)
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
expected_size
=
(
123
,
321
)
with
self
.
create_dataset
(
output_format
=
"TCHW"
,
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
HMDB51TestCase
(
datasets_utils
.
VideoDatasetTestCase
):
...
...
@@ -1237,8 +1248,9 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
return
f
"2008_
{
idx
:
06
d
}
"
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
(
mode
=
"segmentation"
)
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
expected_size
=
(
123
,
321
)
with
self
.
create_dataset
(
mode
=
"segmentation"
,
transforms
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
FakeDataTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
...
@@ -1690,8 +1702,9 @@ class KittiTestCase(datasets_utils.ImageDatasetTestCase):
return
split_to_num_examples
[
config
[
"train"
]]
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
()
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
expected_size
=
(
123
,
321
)
with
self
.
create_dataset
(
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
SvhnTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
...
@@ -2568,8 +2581,9 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
return
(
image_id
,
class_id
,
species
,
breed_id
)
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
()
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
expected_size
=
(
123
,
321
)
with
self
.
create_dataset
(
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
StanfordCarsTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
...
torchvision/tv_tensors/_dataset_wrapper.py
View file @
b80bdb75
...
...
@@ -6,6 +6,7 @@ import collections.abc
import
contextlib
from
collections
import
defaultdict
from
copy
import
copy
import
torch
...
...
@@ -198,8 +199,19 @@ class VisionDatasetTVTensorWrapper:
def
__len__
(
self
):
return
len
(
self
.
_dataset
)
# TODO: maybe we should use __getstate__ and __setstate__ instead of __reduce__, as recommended in the docs.
def
__reduce__
(
self
):
return
wrap_dataset_for_transforms_v2
,
(
self
.
_dataset
,
self
.
_target_keys
)
# __reduce__ gets called when we try to pickle the dataset.
# In a DataLoader with spawn context, this gets called `num_workers` times from the main process.
# We have to reset the [target_]transform[s] attributes of the dataset
# to their original values, because we previously set them to None in __init__().
dataset
=
copy
(
self
.
_dataset
)
dataset
.
transform
=
self
.
transform
dataset
.
transforms
=
self
.
transforms
dataset
.
target_transform
=
self
.
target_transform
return
wrap_dataset_for_transforms_v2
,
(
dataset
,
self
.
_target_keys
)
def
raise_not_supported
(
description
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment