Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b80bdb75
Unverified
Commit
b80bdb75
authored
Oct 27, 2023
by
Nicolas Hug
Committed by
GitHub
Oct 27, 2023
Browse files
Fix v2 transforms in spawn mp context (#8067)
parent
96d2ce91
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
70 additions
and
40 deletions
+70
-40
test/datasets_utils.py
test/datasets_utils.py
+19
-15
test/test_datasets.py
test/test_datasets.py
+38
-24
torchvision/tv_tensors/_dataset_wrapper.py
torchvision/tv_tensors/_dataset_wrapper.py
+13
-1
No files found.
test/datasets_utils.py
View file @
b80bdb75
...
@@ -27,7 +27,11 @@ import torchvision.datasets
...
@@ -27,7 +27,11 @@ import torchvision.datasets
import
torchvision.io
import
torchvision.io
from
common_utils
import
disable_console_output
,
get_tmp_dir
from
common_utils
import
disable_console_output
,
get_tmp_dir
from
torch.utils._pytree
import
tree_any
from
torch.utils._pytree
import
tree_any
from
torch.utils.data
import
DataLoader
from
torchvision
import
tv_tensors
from
torchvision.datasets
import
wrap_dataset_for_transforms_v2
from
torchvision.transforms.functional
import
get_dimensions
from
torchvision.transforms.functional
import
get_dimensions
from
torchvision.transforms.v2.functional
import
get_size
__all__
=
[
__all__
=
[
...
@@ -568,9 +572,6 @@ class DatasetTestCase(unittest.TestCase):
...
@@ -568,9 +572,6 @@ class DatasetTestCase(unittest.TestCase):
@
test_all_configs
@
test_all_configs
def
test_transforms_v2_wrapper
(
self
,
config
):
def
test_transforms_v2_wrapper
(
self
,
config
):
from
torchvision
import
tv_tensors
from
torchvision.datasets
import
wrap_dataset_for_transforms_v2
try
:
try
:
with
self
.
create_dataset
(
config
)
as
(
dataset
,
info
):
with
self
.
create_dataset
(
config
)
as
(
dataset
,
info
):
for
target_keys
in
[
None
,
"all"
]:
for
target_keys
in
[
None
,
"all"
]:
...
@@ -709,26 +710,29 @@ def _no_collate(batch):
...
@@ -709,26 +710,29 @@ def _no_collate(batch):
return
batch
return
batch
def
check_transforms_v2_wrapper_spawn
(
dataset
):
def
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
):
# On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new
# This check ensures that the wrapped datasets can be used with multiprocessing_context="spawn" in the DataLoader.
# subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what
# We also check that transforms are applied correctly as a non-regression test for
# we are enforcing here.
# https://github.com/pytorch/vision/issues/8066
if
platform
.
system
()
!=
"Darwin"
:
# Implicitly, this also checks that the wrapped datasets are pickleable.
pytest
.
skip
(
"Multiprocessing spawning is only checked on macOS."
)
from
torch.utils.data
import
DataLoader
# To save CI/test time, we only check on Windows where "spawn" is the default
from
torchvision
import
tv_tensors
if
platform
.
system
()
!=
"Windows"
:
from
torchvision.datasets
import
wrap_dataset_for_transforms_v2
pytest
.
skip
(
"Multiprocessing spawning is only checked on macOS."
)
wrapped_dataset
=
wrap_dataset_for_transforms_v2
(
dataset
)
wrapped_dataset
=
wrap_dataset_for_transforms_v2
(
dataset
)
dataloader
=
DataLoader
(
wrapped_dataset
,
num_workers
=
2
,
multiprocessing_context
=
"spawn"
,
collate_fn
=
_no_collate
)
dataloader
=
DataLoader
(
wrapped_dataset
,
num_workers
=
2
,
multiprocessing_context
=
"spawn"
,
collate_fn
=
_no_collate
)
for
wrapped_sample
in
dataloader
:
def
resize_was_applied
(
item
):
assert
tree_any
(
# Checking the size of the output ensures that the Resize transform was correctly applied
lambda
item
:
isinstance
(
item
,
(
tv_tensors
.
Image
,
tv_tensors
.
Video
,
PIL
.
Image
.
Image
)),
wrapped_sample
return
isinstance
(
item
,
(
tv_tensors
.
Image
,
tv_tensors
.
Video
,
PIL
.
Image
.
Image
))
and
get_size
(
item
)
==
list
(
expected_size
)
)
for
wrapped_sample
in
dataloader
:
assert
tree_any
(
resize_was_applied
,
wrapped_sample
)
def
create_image_or_video_tensor
(
size
:
Sequence
[
int
])
->
torch
.
Tensor
:
def
create_image_or_video_tensor
(
size
:
Sequence
[
int
])
->
torch
.
Tensor
:
r
"""Create a random uint8 tensor.
r
"""Create a random uint8 tensor.
...
...
test/test_datasets.py
View file @
b80bdb75
...
@@ -24,6 +24,7 @@ import torch
...
@@ -24,6 +24,7 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
common_utils
import
combinations_grid
from
common_utils
import
combinations_grid
from
torchvision
import
datasets
from
torchvision
import
datasets
from
torchvision.transforms
import
v2
class
STL10TestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
STL10TestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
@@ -184,8 +185,9 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -184,8 +185,9 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
f
"
{
actual
}
is not
{
expected
}
"
,
f
"
{
actual
}
is not
{
expected
}
"
,
def
test_transforms_v2_wrapper_spawn
(
self
):
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
(
target_type
=
"category"
)
as
(
dataset
,
_
):
expected_size
=
(
123
,
321
)
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
with
self
.
create_dataset
(
target_type
=
"category"
,
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
Caltech256TestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
Caltech256TestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
@@ -263,8 +265,9 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -263,8 +265,9 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
return
split_to_num_examples
[
config
[
"split"
]]
return
split_to_num_examples
[
config
[
"split"
]]
def
test_transforms_v2_wrapper_spawn
(
self
):
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
()
as
(
dataset
,
_
):
expected_size
=
(
123
,
321
)
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
with
self
.
create_dataset
(
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
CityScapesTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
CityScapesTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
@@ -391,9 +394,10 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -391,9 +394,10 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
(
polygon_target
,
info
[
"expected_polygon_target"
])
(
polygon_target
,
info
[
"expected_polygon_target"
])
def
test_transforms_v2_wrapper_spawn
(
self
):
def
test_transforms_v2_wrapper_spawn
(
self
):
expected_size
=
(
123
,
321
)
for
target_type
in
[
"instance"
,
"semantic"
,
[
"instance"
,
"semantic"
]]:
for
target_type
in
[
"instance"
,
"semantic"
,
[
"instance"
,
"semantic"
]]:
with
self
.
create_dataset
(
target_type
=
target_type
)
as
(
dataset
,
_
):
with
self
.
create_dataset
(
target_type
=
target_type
,
transform
=
v2
.
Resize
(
size
=
expected_size
)
)
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
ImageNetTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
ImageNetTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
@@ -427,8 +431,9 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -427,8 +431,9 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
return
num_examples
return
num_examples
def
test_transforms_v2_wrapper_spawn
(
self
):
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
()
as
(
dataset
,
_
):
expected_size
=
(
123
,
321
)
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
with
self
.
create_dataset
(
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
CIFAR10TestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
CIFAR10TestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
@@ -625,9 +630,10 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -625,9 +630,10 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
assert
merged_imgs_names
==
all_imgs_names
assert
merged_imgs_names
==
all_imgs_names
def
test_transforms_v2_wrapper_spawn
(
self
):
def
test_transforms_v2_wrapper_spawn
(
self
):
expected_size
=
(
123
,
321
)
for
target_type
in
[
"identity"
,
"bbox"
,
[
"identity"
,
"bbox"
]]:
for
target_type
in
[
"identity"
,
"bbox"
,
[
"identity"
,
"bbox"
]]:
with
self
.
create_dataset
(
target_type
=
target_type
)
as
(
dataset
,
_
):
with
self
.
create_dataset
(
target_type
=
target_type
,
transform
=
v2
.
Resize
(
size
=
expected_size
)
)
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
VOCSegmentationTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
VOCSegmentationTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
@@ -717,8 +723,9 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -717,8 +723,9 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
return
data
return
data
def
test_transforms_v2_wrapper_spawn
(
self
):
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
()
as
(
dataset
,
_
):
expected_size
=
(
123
,
321
)
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
with
self
.
create_dataset
(
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
VOCDetectionTestCase
(
VOCSegmentationTestCase
):
class
VOCDetectionTestCase
(
VOCSegmentationTestCase
):
...
@@ -741,8 +748,9 @@ class VOCDetectionTestCase(VOCSegmentationTestCase):
...
@@ -741,8 +748,9 @@ class VOCDetectionTestCase(VOCSegmentationTestCase):
assert
object
==
info
[
"annotation"
]
assert
object
==
info
[
"annotation"
]
def
test_transforms_v2_wrapper_spawn
(
self
):
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
()
as
(
dataset
,
_
):
expected_size
=
(
123
,
321
)
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
with
self
.
create_dataset
(
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
CocoDetectionTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
CocoDetectionTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
@@ -815,8 +823,9 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -815,8 +823,9 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
return
file
return
file
def
test_transforms_v2_wrapper_spawn
(
self
):
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
()
as
(
dataset
,
_
):
expected_size
=
(
123
,
321
)
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
with
self
.
create_dataset
(
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
CocoCaptionsTestCase
(
CocoDetectionTestCase
):
class
CocoCaptionsTestCase
(
CocoDetectionTestCase
):
...
@@ -1005,9 +1014,11 @@ class KineticsTestCase(datasets_utils.VideoDatasetTestCase):
...
@@ -1005,9 +1014,11 @@ class KineticsTestCase(datasets_utils.VideoDatasetTestCase):
)
)
return
num_videos_per_class
*
len
(
classes
)
return
num_videos_per_class
*
len
(
classes
)
@
pytest
.
mark
.
xfail
(
reason
=
"FIXME"
)
def
test_transforms_v2_wrapper_spawn
(
self
):
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
(
output_format
=
"TCHW"
)
as
(
dataset
,
_
):
expected_size
=
(
123
,
321
)
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
with
self
.
create_dataset
(
output_format
=
"TCHW"
,
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
HMDB51TestCase
(
datasets_utils
.
VideoDatasetTestCase
):
class
HMDB51TestCase
(
datasets_utils
.
VideoDatasetTestCase
):
...
@@ -1237,8 +1248,9 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -1237,8 +1248,9 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
return
f
"2008_
{
idx
:
06
d
}
"
return
f
"2008_
{
idx
:
06
d
}
"
def
test_transforms_v2_wrapper_spawn
(
self
):
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
(
mode
=
"segmentation"
)
as
(
dataset
,
_
):
expected_size
=
(
123
,
321
)
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
with
self
.
create_dataset
(
mode
=
"segmentation"
,
transforms
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
FakeDataTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
FakeDataTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
@@ -1690,8 +1702,9 @@ class KittiTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -1690,8 +1702,9 @@ class KittiTestCase(datasets_utils.ImageDatasetTestCase):
return
split_to_num_examples
[
config
[
"train"
]]
return
split_to_num_examples
[
config
[
"train"
]]
def
test_transforms_v2_wrapper_spawn
(
self
):
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
()
as
(
dataset
,
_
):
expected_size
=
(
123
,
321
)
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
with
self
.
create_dataset
(
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
SvhnTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
SvhnTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
@@ -2568,8 +2581,9 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -2568,8 +2581,9 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
return
(
image_id
,
class_id
,
species
,
breed_id
)
return
(
image_id
,
class_id
,
species
,
breed_id
)
def
test_transforms_v2_wrapper_spawn
(
self
):
def
test_transforms_v2_wrapper_spawn
(
self
):
with
self
.
create_dataset
()
as
(
dataset
,
_
):
expected_size
=
(
123
,
321
)
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
)
with
self
.
create_dataset
(
transform
=
v2
.
Resize
(
size
=
expected_size
))
as
(
dataset
,
_
):
datasets_utils
.
check_transforms_v2_wrapper_spawn
(
dataset
,
expected_size
=
expected_size
)
class
StanfordCarsTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
StanfordCarsTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
...
torchvision/tv_tensors/_dataset_wrapper.py
View file @
b80bdb75
...
@@ -6,6 +6,7 @@ import collections.abc
...
@@ -6,6 +6,7 @@ import collections.abc
import
contextlib
import
contextlib
from
collections
import
defaultdict
from
collections
import
defaultdict
from
copy
import
copy
import
torch
import
torch
...
@@ -198,8 +199,19 @@ class VisionDatasetTVTensorWrapper:
...
@@ -198,8 +199,19 @@ class VisionDatasetTVTensorWrapper:
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
_dataset
)
return
len
(
self
.
_dataset
)
# TODO: maybe we should use __getstate__ and __setstate__ instead of __reduce__, as recommended in the docs.
def
__reduce__
(
self
):
def
__reduce__
(
self
):
return
wrap_dataset_for_transforms_v2
,
(
self
.
_dataset
,
self
.
_target_keys
)
# __reduce__ gets called when we try to pickle the dataset.
# In a DataLoader with spawn context, this gets called `num_workers` times from the main process.
# We have to reset the [target_]transform[s] attributes of the dataset
# to their original values, because we previously set them to None in __init__().
dataset
=
copy
(
self
.
_dataset
)
dataset
.
transform
=
self
.
transform
dataset
.
transforms
=
self
.
transforms
dataset
.
target_transform
=
self
.
target_transform
return
wrap_dataset_for_transforms_v2
,
(
dataset
,
self
.
_target_keys
)
def
raise_not_supported
(
description
):
def
raise_not_supported
(
description
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment