OpenDAS / vision / Commits

Commit 27b84916 (unverified), authored Apr 06, 2023 by Philip Meier, committed via GitHub Apr 06, 2023
Parent: ce653d8b

only return small set of targets by default from dataset wrapper (#7488)
Showing 4 changed files with 223 additions and 69 deletions:

    gallery/plot_transforms_v2_e2e.py             +3   -2
    test/datasets_utils.py                        +15  -3
    test/test_datasets.py                         +3   -1
    torchvision/datapoints/_dataset_wrapper.py    +202 -63
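Before the file-by-file diff, a minimal sketch of the behavioral change; the dataset paths are placeholders, not part of this commit:

    # Sketch of the new default: the wrapped CocoDetection target now carries
    # only "boxes" and "labels"; pass target_keys to opt back into more.
    from torchvision import datasets

    coco = datasets.CocoDetection("path/to/images", "path/to/annotations.json")

    wrapped = datasets.wrap_dataset_for_transforms_v2(coco)
    _, target = wrapped[0]
    print(sorted(target.keys()))  # expected: ['boxes', 'labels']

    # Roughly the pre-commit behavior of returning the full target:
    wrapped = datasets.wrap_dataset_for_transforms_v2(coco, target_keys="all")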
gallery/plot_transforms_v2_e2e.py

@@ -75,7 +75,8 @@ print(type(target), type(target[0]), list(target[0].keys()))
 # :func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For
 # :class:`~torchvision.datasets.CocoDetection`, this changes the target structure to a single dictionary of lists. It
 # also adds the key-value-pairs ``"boxes"``, ``"masks"``, and ``"labels"`` wrapped in the corresponding
-# ``torchvision.datapoints``.
+# ``torchvision.datapoints``. By default, it only returns ``"boxes"`` and ``"labels"`` to avoid transforming unnecessary
+# items down the line, but you can pass the ``target_keys`` parameter for fine-grained control.
 dataset = datasets.wrap_dataset_for_transforms_v2(dataset)

@@ -83,7 +84,7 @@ sample = dataset[0]
 image, target = sample
 print(type(image))
 print(type(target), list(target.keys()))
-print(type(target["boxes"]), type(target["masks"]), type(target["labels"]))
+print(type(target["boxes"]), type(target["labels"]))

 ########################################################################################################################
 # As baseline, let's have a look at a sample without transformations:
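Since the e2e gallery pipeline no longer receives ``"masks"`` by default, a mask-aware pipeline would opt back in explicitly; a sketch reusing the gallery's `dataset` (key names taken from the wrapper changes below):

    # Keep masks alongside boxes and labels.
    dataset = datasets.wrap_dataset_for_transforms_v2(
        dataset, target_keys={"boxes", "masks", "labels"}
    )
    image, target = dataset[0]
    print(sorted(target.keys()))  # expected: ['boxes', 'labels', 'masks']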
test/datasets_utils.py

@@ -572,8 +572,20 @@ class DatasetTestCase(unittest.TestCase):
         try:
             with self.create_dataset(config) as (dataset, _):
-                wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
+                for target_keys in [None, "all"]:
+                    if target_keys is not None and self.DATASET_CLASS not in {
+                        torchvision.datasets.CocoDetection,
+                        torchvision.datasets.VOCDetection,
+                        torchvision.datasets.Kitti,
+                        torchvision.datasets.WIDERFace,
+                    }:
+                        with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
+                            wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
+                        continue

-                wrapped_sample = wrapped_dataset[0]
+                    wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
+                    wrapped_sample = wrapped_dataset[0]

-                assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
+                    assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
         except TypeError as error:
             msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
test/test_datasets.py

@@ -771,6 +771,8 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
                     bbox=torch.rand(4).tolist(),
                     segmentation=[torch.rand(8).tolist()],
                     category_id=int(torch.randint(91, ())),
+                    area=float(torch.rand(1)),
+                    iscrowd=int(torch.randint(2, size=(1,))),
                 )
             )
             annotion_id += 1

@@ -3336,7 +3338,7 @@ class TestDatasetWrapper:
         mocker.patch.dict(
             datapoints._dataset_wrapper.WRAPPER_FACTORIES,
             clear=False,
-            values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
+            values={datasets.FakeData: lambda dataset, target_keys: lambda idx, sample: sentinel},
         )

         class MyFakeData(datasets.FakeData):
torchvision/datapoints/_dataset_wrapper.py

@@ -2,6 +2,8 @@
 from __future__ import annotations

+import collections.abc
 import contextlib
 from collections import defaultdict
@@ -14,7 +16,7 @@ from torchvision.transforms.v2 import functional as F
 __all__ = ["wrap_dataset_for_transforms_v2"]

-def wrap_dataset_for_transforms_v2(dataset):
+def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
     """[BETA] Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.

     .. v2betastatus:: wrap_dataset_for_transforms_v2 function
@@ -36,15 +38,17 @@ def wrap_dataset_for_transforms_v2(dataset):
     * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as list of dicts, the wrapper
       returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
       ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``.
-      The original keys are preserved.
+      The original keys are preserved. If ``target_keys`` is omitted, returns only the values for the ``"boxes"``
+      and ``"labels"``.
     * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
       the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
       preserved.
+      If ``target_keys`` is omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
     * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
       coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
-    * :class:`~torchvision.datasets.Kitti`: Instead returning the target as list of dicts the wrapper returns a dict
-      of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
-      in the corresponding ``torchvision.datapoints``. The original keys are preserved.
+    * :class:`~torchvision.datasets.Kitti`: Instead of returning the target as a list of dicts, the wrapper returns a
+      dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
+      in the corresponding ``torchvision.datapoints``. The original keys are preserved. If ``target_keys`` is
+      omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
     * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
       :class:`~torchvision.datapoints.Mask` datapoint.
     * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
@@ -61,13 +65,13 @@ def wrap_dataset_for_transforms_v2(dataset):
     Segmentation datasets
-        Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation` return a two-tuple of
+        Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of
         :class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
         segmentation mask into a :class:`~torchvision.datapoints.Mask` (second item).

     Video classification datasets
-        Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics` return a three-tuple containing a
+        Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing a
         :class:`torch.Tensor` for the video and audio and a :class:`int` as label. This wrapper wraps the video into a
         :class:`~torchvision.datapoints.Video` while leaving the other items as is.
@@ -78,8 +82,23 @@ def wrap_dataset_for_transforms_v2(dataset):
     Args:
         dataset: the dataset instance to wrap for compatibility with transforms v2.
+        target_keys: Target keys to return in case the target is a dictionary. If ``None`` (default), selected keys are
+            specific to the dataset. If ``"all"``, returns the full target. Can also be a collection of strings for
+            fine-grained access. Currently only supported for :class:`~torchvision.datasets.CocoDetection`,
+            :class:`~torchvision.datasets.VOCDetection`, :class:`~torchvision.datasets.Kitti`, and
+            :class:`~torchvision.datasets.WIDERFace`. See above for details.
     """
-    return VisionDatasetDatapointWrapper(dataset)
+    if not (
+        target_keys is None
+        or target_keys == "all"
+        or (isinstance(target_keys, collections.abc.Collection) and all(isinstance(key, str) for key in target_keys))
+    ):
+        raise ValueError(
+            f"`target_keys` can be None, 'all', or a collection of strings denoting the keys to be returned, "
+            f"but got {target_keys}"
+        )
+    return VisionDatasetDatapointWrapper(dataset, target_keys)


 class WrapperFactories(dict):
@@ -99,7 +118,7 @@ WRAPPER_FACTORIES = WrapperFactories()
 class VisionDatasetDatapointWrapper(Dataset):
-    def __init__(self, dataset):
+    def __init__(self, dataset, target_keys):
         dataset_cls = type(dataset)

         if not isinstance(dataset, datasets.VisionDataset):
@@ -111,6 +130,16 @@ class VisionDatasetDatapointWrapper(Dataset):
         for cls in dataset_cls.mro():
             if cls in WRAPPER_FACTORIES:
                 wrapper_factory = WRAPPER_FACTORIES[cls]
+                if target_keys is not None and cls not in {
+                    datasets.CocoDetection,
+                    datasets.VOCDetection,
+                    datasets.Kitti,
+                    datasets.WIDERFace,
+                }:
+                    raise ValueError(
+                        f"`target_keys` is currently only supported for `CocoDetection`, `VOCDetection`, `Kitti`, "
+                        f"and `WIDERFace`, but got {cls.__name__}."
+                    )
                 break
             elif cls is datasets.VisionDataset:
                 # TODO: If we have documentation on how to do that, put a link in the error message.
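A quick illustration of the guard above; a sketch assuming a classification dataset such as CIFAR10 (any wrapped dataset outside the four supported classes behaves the same, and the root path is a placeholder):

    # Sketch: passing target_keys to an unsupported dataset raises immediately.
    from torchvision import datasets

    cifar = datasets.CIFAR10("path/to/root")
    datasets.wrap_dataset_for_transforms_v2(cifar, target_keys="all")
    # ValueError: `target_keys` is currently only supported for `CocoDetection`,
    # `VOCDetection`, `Kitti`, and `WIDERFace`, but got CIFAR10.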
@@ -123,7 +152,7 @@ class VisionDatasetDatapointWrapper(Dataset):
             raise TypeError(msg)

         self._dataset = dataset
-        self._wrapper = wrapper_factory(dataset)
+        self._wrapper = wrapper_factory(dataset, target_keys)

         # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
         # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint
@@ -170,7 +199,7 @@ def identity(item):
     return item


-def identity_wrapper_factory(dataset):
+def identity_wrapper_factory(dataset, target_keys):
     def wrapper(idx, sample):
         return sample
@@ -181,6 +210,20 @@ def pil_image_to_mask(pil_image):
     return datapoints.Mask(pil_image)


+def parse_target_keys(target_keys, *, available, default):
+    if target_keys is None:
+        target_keys = default
+
+    if target_keys == "all":
+        target_keys = available
+    else:
+        target_keys = set(target_keys)
+        extra = target_keys - available
+        if extra:
+            raise ValueError(f"Target keys {sorted(extra)} are not available")
+
+    return target_keys
+
+
 def list_of_dicts_to_dict_of_lists(list_of_dicts):
     dict_of_lists = defaultdict(list)
     for dct in list_of_dicts:
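The new helper has three modes; a sketch of its semantics, using the COCO factory's `available`/`default` sets from further down in this diff:

    available = {"segmentation", "area", "iscrowd", "image_id", "bbox",
                 "category_id", "boxes", "masks", "labels"}
    default = {"boxes", "labels"}

    parse_target_keys(None, available=available, default=default)   # -> default
    parse_target_keys("all", available=available, default=default)  # -> available
    parse_target_keys(["boxes", "area"], available=available, default=default)
    # -> {"boxes", "area"}
    parse_target_keys(["typo"], available=available, default=default)
    # ValueError: Target keys ['typo'] are not available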
@@ -203,8 +246,8 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
     return wrapped_target


-def classification_wrapper_factory(dataset):
-    return identity_wrapper_factory(dataset)
+def classification_wrapper_factory(dataset, target_keys):
+    return identity_wrapper_factory(dataset, target_keys)


 for dataset_cls in [
@@ -221,7 +264,7 @@ for dataset_cls in [
     WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)


-def segmentation_wrapper_factory(dataset):
+def segmentation_wrapper_factory(dataset, target_keys):
     def wrapper(idx, sample):
         image, mask = sample
         return image, pil_image_to_mask(mask)
@@ -235,7 +278,7 @@ for dataset_cls in [
     WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory)


-def video_classification_wrapper_factory(dataset):
+def video_classification_wrapper_factory(dataset, target_keys):
     if dataset.video_clips.output_format == "THWC":
         raise RuntimeError(
             f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, "
@@ -261,15 +304,33 @@ for dataset_cls in [
 @WRAPPER_FACTORIES.register(datasets.Caltech101)
-def caltech101_wrapper_factory(dataset):
+def caltech101_wrapper_factory(dataset, target_keys):
     if "annotation" in dataset.target_type:
         raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")

-    return classification_wrapper_factory(dataset)
+    return classification_wrapper_factory(dataset, target_keys)


 @WRAPPER_FACTORIES.register(datasets.CocoDetection)
-def coco_dectection_wrapper_factory(dataset):
+def coco_dectection_wrapper_factory(dataset, target_keys):
+    target_keys = parse_target_keys(
+        target_keys,
+        available={
+            # native
+            "segmentation",
+            "area",
+            "iscrowd",
+            "image_id",
+            "bbox",
+            "category_id",
+            # added by the wrapper
+            "boxes",
+            "masks",
+            "labels",
+        },
+        default={"boxes", "labels"},
+    )
+
     def segmentation_to_mask(segmentation, *, spatial_size):
         from pycocotools import mask
@@ -288,12 +349,16 @@ def coco_dectection_wrapper_factory(dataset):
         if not target:
             return image, dict(image_id=image_id)

-        spatial_size = tuple(F.get_spatial_size(image))
         batched_target = list_of_dicts_to_dict_of_lists(target)
+        target = {}

-        batched_target["image_id"] = image_id
+        if "image_id" in target_keys:
+            target["image_id"] = image_id

-        batched_target["boxes"] = F.convert_format_bounding_box(
-            datapoints.BoundingBox(
-                batched_target["bbox"],
-                format=datapoints.BoundingBoxFormat.XYWH,
+        spatial_size = tuple(F.get_spatial_size(image))
+        if "boxes" in target_keys:
+            target["boxes"] = F.convert_format_bounding_box(
+                datapoints.BoundingBox(
+                    batched_target["bbox"],
+                    format=datapoints.BoundingBoxFormat.XYWH,
@@ -301,7 +366,9 @@ def coco_dectection_wrapper_factory(dataset):
                 ),
                 new_format=datapoints.BoundingBoxFormat.XYXY,
             )
-        batched_target["masks"] = datapoints.Mask(
-            torch.stack(
-                [
-                    segmentation_to_mask(segmentation, spatial_size=spatial_size)
+        if "masks" in target_keys:
+            target["masks"] = datapoints.Mask(
+                torch.stack(
+                    [
+                        segmentation_to_mask(segmentation, spatial_size=spatial_size)
@@ -309,9 +376,14 @@ def coco_dectection_wrapper_factory(dataset):
                 ]
             ),
         )
-        batched_target["labels"] = torch.tensor(batched_target["category_id"])
-        return image, batched_target
+        if "labels" in target_keys:
+            target["labels"] = torch.tensor(batched_target["category_id"])
+
+        for target_key in target_keys - {"image_id", "boxes", "masks", "labels"}:
+            target[target_key] = batched_target[target_key]
+
+        return image, target

     return wrapper
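One consequence of the pass-through loop above: native COCO keys requested via ``target_keys`` come back as plain lists from the dict-of-lists conversion, and only the wrapper-added keys are datapoints. A sketch, where `coco` is an assumed CocoDetection instance:

    wrapped = datasets.wrap_dataset_for_transforms_v2(
        coco, target_keys={"boxes", "labels", "area", "iscrowd"}
    )
    _, target = wrapped[0]
    # target["boxes"] is a datapoints.BoundingBox and target["labels"] a
    # torch.Tensor, while target["area"] and target["iscrowd"] stay plain lists.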
@@ -346,12 +418,28 @@ VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC_DETECTION_CATEGORIES))))
 @WRAPPER_FACTORIES.register(datasets.VOCDetection)
-def voc_detection_wrapper_factory(dataset):
+def voc_detection_wrapper_factory(dataset, target_keys):
+    target_keys = parse_target_keys(
+        target_keys,
+        available={
+            # native
+            "annotation",
+            # added by the wrapper
+            "boxes",
+            "labels",
+        },
+        default={"boxes", "labels"},
+    )
+
     def wrapper(idx, sample):
         image, target = sample

         batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])

-        target["boxes"] = datapoints.BoundingBox(
-            [
-                [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
+        if "annotation" not in target_keys:
+            target = {}
+
+        if "boxes" in target_keys:
+            target["boxes"] = datapoints.BoundingBox(
+                [
+                    [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]

@@ -360,6 +448,8 @@ def voc_detection_wrapper_factory(dataset):
                 format=datapoints.BoundingBoxFormat.XYXY,
                 spatial_size=(image.height, image.width),
             )
-        target["labels"] = torch.tensor(
-            [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]]
-        )
+        if "labels" in target_keys:
+            target["labels"] = torch.tensor(
+                [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]]
+            )
@@ -370,15 +460,15 @@ def voc_detection_wrapper_factory(dataset):
 @WRAPPER_FACTORIES.register(datasets.SBDataset)
-def sbd_wrapper(dataset):
+def sbd_wrapper(dataset, target_keys):
     if dataset.mode == "boundaries":
         raise_not_supported("SBDataset with mode='boundaries'")

-    return segmentation_wrapper_factory(dataset)
+    return segmentation_wrapper_factory(dataset, target_keys)


 @WRAPPER_FACTORIES.register(datasets.CelebA)
-def celeba_wrapper_factory(dataset):
+def celeba_wrapper_factory(dataset, target_keys):
     if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
         raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")
@@ -410,17 +500,47 @@ KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES))))
 @WRAPPER_FACTORIES.register(datasets.Kitti)
-def kitti_wrapper_factory(dataset):
+def kitti_wrapper_factory(dataset, target_keys):
+    target_keys = parse_target_keys(
+        target_keys,
+        available={
+            # native
+            "type",
+            "truncated",
+            "occluded",
+            "alpha",
+            "bbox",
+            "dimensions",
+            "location",
+            "rotation_y",
+            # added by the wrapper
+            "boxes",
+            "labels",
+        },
+        default={"boxes", "labels"},
+    )
+
     def wrapper(idx, sample):
         image, target = sample

-        if target is not None:
-            target = list_of_dicts_to_dict_of_lists(target)
+        if target is None:
+            return image, target

-        target["boxes"] = datapoints.BoundingBox(
-            target["bbox"], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(image.height, image.width)
-        )
-        target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in target["type"]])
+        batched_target = list_of_dicts_to_dict_of_lists(target)
+        target = {}
+
+        if "boxes" in target_keys:
+            target["boxes"] = datapoints.BoundingBox(
+                batched_target["bbox"],
+                format=datapoints.BoundingBoxFormat.XYXY,
+                spatial_size=(image.height, image.width),
+            )
+
+        if "labels" in target_keys:
+            target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in batched_target["type"]])
+
+        for target_key in target_keys - {"boxes", "labels"}:
+            target[target_key] = batched_target[target_key]

         return image, target
@@ -428,7 +548,7 @@ def kitti_wrapper_factory(dataset):
 @WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
-def oxford_iiit_pet_wrapper_factor(dataset):
+def oxford_iiit_pet_wrapper_factor(dataset, target_keys):
     def wrapper(idx, sample):
         image, target = sample
@@ -447,7 +567,7 @@ def oxford_iiit_pet_wrapper_factor(dataset):
 @WRAPPER_FACTORIES.register(datasets.Cityscapes)
-def cityscapes_wrapper_factory(dataset):
+def cityscapes_wrapper_factory(dataset, target_keys):
     if any(target_type in dataset.target_type for target_type in ["polygon", "color"]):
         raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`")
@@ -482,11 +602,30 @@ def cityscapes_wrapper_factory(dataset):
 @WRAPPER_FACTORIES.register(datasets.WIDERFace)
-def widerface_wrapper(dataset):
+def widerface_wrapper(dataset, target_keys):
+    target_keys = parse_target_keys(
+        target_keys,
+        available={
+            "bbox",
+            "blur",
+            "expression",
+            "illumination",
+            "occlusion",
+            "pose",
+            "invalid",
+        },
+        default="all",
+    )
+
     def wrapper(idx, sample):
         image, target = sample

-        if target is not None:
-            target["bbox"] = F.convert_format_bounding_box(
-                datapoints.BoundingBox(
-                    target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
+        if target is None:
+            return image, target
+
+        target = {key: target[key] for key in target_keys}
+
+        if "bbox" in target_keys:
+            target["bbox"] = F.convert_format_bounding_box(
+                datapoints.BoundingBox(
+                    target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
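Note the asymmetry: unlike the detection wrappers, the WIDERFace factory defaults to ``"all"``, so an unqualified wrap keeps every native key. A sketch, where `widerface` is an assumed WIDERFace instance:

    wrapped = datasets.wrap_dataset_for_transforms_v2(widerface)  # all seven keys
    wrapped = datasets.wrap_dataset_for_transforms_v2(widerface, target_keys={"bbox"})
    # -> target contains only "bbox", wrapped into a BoundingBox datapoint.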