OpenDAS / vision / Commits

Commit 054432d2 (unverified): enforce pickleability for v2 transforms and wrapped datasets (#7860)
Authored Aug 24, 2023 by Philip Meier; committed by GitHub on Aug 24, 2023.
Parent: 92882b69
Showing 6 changed files with 102 additions and 10 deletions (+102 -10):

  test/datasets_utils.py                       +27  -1
  test/test_datasets.py                        +56  -1
  test/test_transforms_v2.py                    +5  -1
  test/test_transforms_v2_refactored.py         +3  -0
  torchvision/datapoints/_dataset_wrapper.py    +4  -0
  torchvision/datasets/widerface.py             +7  -7
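The property these changes enforce is simply that every v2 transform and every wrapped dataset survives a pickle round-trip, since that is what spawn-based DataLoader workers require. As a minimal sketch of the check (the specific transform and the torchvision.transforms.v2 import below are illustrative choices, not part of this commit):

import pickle

from torchvision.transforms import v2

# Same round-trip that check_transform() in test_transforms_v2_refactored.py
# now performs for every transform under test.
transform = v2.RandomHorizontalFlip(p=0.5)
restored = pickle.loads(pickle.dumps(transform))
assert isinstance(restored, v2.RandomHorizontalFlip)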
test/datasets_utils.py

@@ -5,6 +5,7 @@ import inspect
 import itertools
 import os
 import pathlib
+import platform
 import random
 import shutil
 import string
@@ -548,7 +549,7 @@ class DatasetTestCase(unittest.TestCase):
     @test_all_configs
     def test_num_examples(self, config):
         with self.create_dataset(config) as (dataset, info):
-            assert len(dataset) == info["num_examples"]
+            assert len(list(dataset)) == len(dataset) == info["num_examples"]

     @test_all_configs
     def test_transforms(self, config):
@@ -692,6 +693,31 @@ class VideoDatasetTestCase(DatasetTestCase):
         super().test_transforms_v2_wrapper.__wrapped__(self, config)


+def _no_collate(batch):
+    return batch
+
+
+def check_transforms_v2_wrapper_spawn(dataset):
+    # On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new
+    # subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what
+    # we are enforcing here.
+    if platform.system() != "Darwin":
+        pytest.skip("Multiprocessing spawning is only checked on macOS.")
+
+    from torch.utils.data import DataLoader
+
+    from torchvision import datapoints
+    from torchvision.datasets import wrap_dataset_for_transforms_v2
+
+    wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
+
+    dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)
+
+    for wrapped_sample in dataloader:
+        assert tree_any(
+            lambda item: isinstance(item, (datapoints.Image, datapoints.Video, PIL.Image.Image)), wrapped_sample
+        )
+
+
 def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
     r"""Create a random uint8 tensor.
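A note on why check_transforms_v2_wrapper_spawn only runs with the "spawn" start method: spawned workers receive the dataset by pickling it, so any non-picklable attribute anywhere in the pipeline breaks loading. A hypothetical, self-contained illustration of the failure mode (none of these names come from the diff):

import pickle


class BrokenDataset:
    def __init__(self):
        # A lambda cannot be pickled, so this dataset cannot be shipped to a
        # spawned worker process.
        self.transform = lambda x: x

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return self.transform(idx)


try:
    pickle.dumps(BrokenDataset())
except (pickle.PicklingError, AttributeError) as exc:
    print(f"not picklable: {exc}")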
test/test_datasets.py

@@ -183,6 +183,10 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
             ), "Type of the combined target does not match the type of the corresponding individual target: "
             f"{actual} is not {expected}",

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset(target_type="category") as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Caltech256
@@ -190,7 +194,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
     def inject_fake_data(self, tmpdir, config):
         tmpdir = pathlib.Path(tmpdir) / "caltech256" / "256_ObjectCategories"

-        categories = ((1, "ak47"), (127, "laptop-101"), (257, "clutter"))
+        categories = ((1, "ak47"), (2, "american-flag"), (3, "backpack"))
         num_images_per_category = 2

         for idx, category in categories:
@@ -258,6 +262,10 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
         return split_to_num_examples[config["split"]]

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Cityscapes
@@ -382,6 +390,11 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase):
             assert isinstance(polygon_img, PIL.Image.Image)
             (polygon_target, info["expected_polygon_target"])

+    def test_transforms_v2_wrapper_spawn(self):
+        for target_type in ["instance", "semantic", ["instance", "semantic"]]:
+            with self.create_dataset(target_type=target_type) as (dataset, _):
+                datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.ImageNet
@@ -413,6 +426,10 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
         torch.save((wnid_to_classes, None), tmpdir / "meta.bin")
         return num_examples

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CIFAR10
@@ -607,6 +624,11 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
         assert merged_imgs_names == all_imgs_names

+    def test_transforms_v2_wrapper_spawn(self):
+        for target_type in ["identity", "bbox", ["identity", "bbox"]]:
+            with self.create_dataset(target_type=target_type) as (dataset, _):
+                datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.VOCSegmentation
@@ -694,6 +716,10 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
         return data

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class VOCDetectionTestCase(VOCSegmentationTestCase):
     DATASET_CLASS = datasets.VOCDetection
@@ -714,6 +740,10 @@ class VOCDetectionTestCase(VOCSegmentationTestCase):
             assert object == info["annotation"]

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CocoDetection
@@ -784,6 +814,10 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
             json.dump(content, fh)
         return file

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class CocoCaptionsTestCase(CocoDetectionTestCase):
     DATASET_CLASS = datasets.CocoCaptions
@@ -800,6 +834,11 @@ class CocoCaptionsTestCase(CocoDetectionTestCase):
             _, captions = dataset[0]
             assert tuple(captions) == tuple(info["captions"])

+    def test_transforms_v2_wrapper_spawn(self):
+        # We need to define this method, because otherwise the test from the super class will
+        # be run
+        pytest.skip("CocoCaptions is currently not supported by the v2 wrapper.")
+

 class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.UCF101
@@ -966,6 +1005,10 @@ class KineticsTestCase(datasets_utils.VideoDatasetTestCase):
             )
         return num_videos_per_class * len(classes)

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset(output_format="TCHW") as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.HMDB51
@@ -1193,6 +1236,10 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
     def _file_stem(self, idx):
         return f"2008_{idx:06d}"

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset(mode="segmentation") as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class FakeDataTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.FakeData
@@ -1642,6 +1689,10 @@ class KittiTestCase(datasets_utils.ImageDatasetTestCase):
         return split_to_num_examples[config["train"]]

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class SvhnTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.SVHN
@@ -2516,6 +2567,10 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
             breed_id = "-1"
         return (image_id, class_id, species, breed_id)

+    def test_transforms_v2_wrapper_spawn(self):
+        with self.create_dataset() as (dataset, _):
+            datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
+

 class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.StanfordCars
test/test_transforms_v2.py

 import itertools
 import pathlib
+import pickle
 import random
 import warnings
@@ -169,8 +170,11 @@ class TestSmoke:
             next(make_vanilla_tensor_images()),
         ],
     )
+    @pytest.mark.parametrize("de_serialize", [lambda t: t, lambda t: pickle.loads(pickle.dumps(t))])
     @pytest.mark.parametrize("device", cpu_and_cuda())
-    def test_common(self, transform, adapter, container_type, image_or_video, device):
+    def test_common(self, transform, adapter, container_type, image_or_video, de_serialize, device):
+        transform = de_serialize(transform)
+
         canvas_size = F.get_size(image_or_video)
         input = dict(
             image_or_video=image_or_video,
test/test_transforms_v2_refactored.py

@@ -2,6 +2,7 @@ import contextlib
 import decimal
 import inspect
 import math
+import pickle
 import re
 from pathlib import Path
 from unittest import mock
@@ -247,6 +248,8 @@ def _check_transform_v1_compatibility(transform, input):
 def check_transform(transform_cls, input, *args, **kwargs):
     transform = transform_cls(*args, **kwargs)

+    pickle.loads(pickle.dumps(transform))
+
     output = transform(input)
     assert isinstance(output, type(input))
torchvision/datapoints/_dataset_wrapper.py

@@ -162,6 +162,7 @@ class VisionDatasetDatapointWrapper:
             raise TypeError(msg)

         self._dataset = dataset
+        self._target_keys = target_keys
         self._wrapper = wrapper_factory(dataset, target_keys)

         # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
@@ -197,6 +198,9 @@ class VisionDatasetDatapointWrapper:
     def __len__(self):
         return len(self._dataset)

+    def __reduce__(self):
+        return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys)
+

 def raise_not_supported(description):
     raise RuntimeError(
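The added __reduce__ tells pickle to rebuild the wrapper by calling wrap_dataset_for_transforms_v2(dataset, target_keys) again instead of pickling the wrapper instance through its class, which would otherwise require that class to be importable by name. A hypothetical, self-contained sketch of the same pattern (Box and build_box are illustrative names, not torchvision APIs):

import pickle


def build_box(contents):
    return Box(contents)


class Box:
    def __init__(self, contents):
        self.contents = contents

    def __reduce__(self):
        # Tell pickle: "recreate me by calling build_box(self.contents)".
        return build_box, (self.contents,)


restored = pickle.loads(pickle.dumps(Box([1, 2, 3])))
assert restored.contents == [1, 2, 3]

On unpickling, pickle calls the factory, so whatever setup the factory performs is re-applied in the worker process.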
torchvision/datasets/widerface.py

@@ -137,13 +137,13 @@ class WIDERFace(VisionDataset):
                         {
                             "img_path": img_path,
                             "annotations": {
-                                "bbox": labels_tensor[:, 0:4],  # x, y, width, height
-                                "blur": labels_tensor[:, 4],
-                                "expression": labels_tensor[:, 5],
-                                "illumination": labels_tensor[:, 6],
-                                "occlusion": labels_tensor[:, 7],
-                                "pose": labels_tensor[:, 8],
-                                "invalid": labels_tensor[:, 9],
+                                "bbox": labels_tensor[:, 0:4].clone(),  # x, y, width, height
+                                "blur": labels_tensor[:, 4].clone(),
+                                "expression": labels_tensor[:, 5].clone(),
+                                "illumination": labels_tensor[:, 6].clone(),
+                                "occlusion": labels_tensor[:, 7].clone(),
+                                "pose": labels_tensor[:, 8].clone(),
+                                "invalid": labels_tensor[:, 9].clone(),
                             },
                         }
                     )
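The diff does not state why the .clone() calls are needed. One plausible reading (my inference, not a claim from the commit): each annotation entry was a view into the shared labels_tensor, and pickling a tensor view serializes the entire underlying storage, so cloning keeps the per-sample annotations independent and compact when samples cross process boundaries. A small sketch of that behavior (sizes are approximate):

import pickle

import torch

labels = torch.zeros(10_000, 10)
view = labels[:, 4]           # shares storage with the full 10_000 x 10 matrix
clone = labels[:, 4].clone()  # independent 10_000-element tensor

# Pickling the view drags the whole parent storage along; the clone does not.
print(len(pickle.dumps(view)))
print(len(pickle.dumps(clone)))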