Unverified commit 27b84916, authored by Philip Meier, committed by GitHub

only return small set of targets by default from dataset wrapper (#7488)

parent ce653d8b
@@ -75,7 +75,8 @@ print(type(target), type(target[0]), list(target[0].keys()))
 # :func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For
 # :class:`~torchvision.datasets.CocoDetection`, this changes the target structure to a single dictionary of lists. It
 # also adds the key-value-pairs ``"boxes"``, ``"masks"``, and ``"labels"`` wrapped in the corresponding
-# ``torchvision.datapoints``.
+# ``torchvision.datapoints``. By default, it only returns ``"boxes"`` and ``"labels"`` to avoid transforming unnecessary
+# items down the line, but you can pass the ``target_keys`` parameter for fine-grained control.
 dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
@@ -83,7 +84,7 @@ sample = dataset[0]
 image, target = sample
 print(type(image))
 print(type(target), list(target.keys()))
-print(type(target["boxes"]), type(target["masks"]), type(target["labels"]))
+print(type(target["boxes"]), type(target["labels"]))
 ########################################################################################################################
 # As baseline, let's have a look at a sample without transformations:
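To make the new default concrete, here is a small illustrative sketch of calling the wrapper with and without ``target_keys``. The dataset paths are placeholders and the printed key sets are indicative of this change, not output copied from the patch:

from torchvision import datasets

# Hypothetical local COCO-style dataset; adjust the paths to your data.
dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")

# Default after this change: the wrapped target only carries "boxes" and "labels".
wrapped = datasets.wrap_dataset_for_transforms_v2(dataset)
_, target = wrapped[0]
print(sorted(target.keys()))  # expected to be roughly ['boxes', 'labels']

# Fine-grained control: explicitly request masks (and anything else) as well.
wrapped = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "masks"})
_, target = wrapped[0]
print(sorted(target.keys()))  # expected to be roughly ['boxes', 'labels', 'masks']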
@@ -572,8 +572,20 @@ class DatasetTestCase(unittest.TestCase):
         try:
             with self.create_dataset(config) as (dataset, _):
-                wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
-                wrapped_sample = wrapped_dataset[0]
-                assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
+                for target_keys in [None, "all"]:
+                    if target_keys is not None and self.DATASET_CLASS not in {
+                        torchvision.datasets.CocoDetection,
+                        torchvision.datasets.VOCDetection,
+                        torchvision.datasets.Kitti,
+                        torchvision.datasets.WIDERFace,
+                    }:
+                        with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
+                            wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
+                        continue
+
+                    wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
+                    wrapped_sample = wrapped_dataset[0]
+
+                    assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
         except TypeError as error:
             msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
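As a rough illustration of what this test exercises: passing ``target_keys`` to the wrapper for a dataset outside the supported set raises, while omitting it still works everywhere. Dataset class and path are placeholders:

from torchvision import datasets

# Hypothetical classification dataset; any dataset outside the supported set behaves the same.
dataset = datasets.CIFAR10("path/to/data")

# Works: without target_keys the wrapper falls back to its per-dataset default behavior.
wrapped = datasets.wrap_dataset_for_transforms_v2(dataset)

# Raises ValueError: target_keys is currently only supported for
# CocoDetection, VOCDetection, Kitti, and WIDERFace.
try:
    datasets.wrap_dataset_for_transforms_v2(dataset, target_keys="all")
except ValueError as error:
    print(error)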
@@ -771,6 +771,8 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
                     bbox=torch.rand(4).tolist(),
                     segmentation=[torch.rand(8).tolist()],
                     category_id=int(torch.randint(91, ())),
+                    area=float(torch.rand(1)),
+                    iscrowd=int(torch.randint(2, size=(1,))),
                 )
             )
             annotion_id += 1
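The fake annotations now carry ``area`` and ``iscrowd`` so that the wrapper's pass-through of native keys can be exercised. For reference, a minimal COCO-style instance annotation, as assumed by this test and the wrapper, looks roughly like this (all values are illustrative):

annotation = dict(
    id=1,                                # annotation id
    image_id=0,                          # id of the image this annotation belongs to
    bbox=[10.0, 20.0, 30.0, 40.0],       # XYWH box, as in the COCO format
    segmentation=[[10.0, 20.0, 40.0, 20.0, 40.0, 60.0, 10.0, 60.0]],  # polygon
    category_id=18,
    area=1200.0,
    iscrowd=0,
)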
@@ -3336,7 +3338,7 @@ class TestDatasetWrapper:
         mocker.patch.dict(
             datapoints._dataset_wrapper.WRAPPER_FACTORIES,
             clear=False,
-            values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
+            values={datasets.FakeData: lambda dataset, target_keys: lambda idx, sample: sentinel},
         )
 
         class MyFakeData(datasets.FakeData):
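The mocked factory reflects the new factory signature: every registered wrapper factory now receives ``(dataset, target_keys)`` and returns a per-sample wrapper. A hedged sketch of what such a factory looks like under this change; ``WRAPPER_FACTORIES`` is an internal registry, so this is for illustration only:

from torchvision import datapoints, datasets

@datapoints._dataset_wrapper.WRAPPER_FACTORIES.register(datasets.FakeData)
def fake_data_wrapper_factory(dataset, target_keys):
    # target_keys is forwarded from wrap_dataset_for_transforms_v2(); it may be
    # None, "all", or a collection of strings, depending on the dataset.
    def wrapper(idx, sample):
        return sample

    return wrapper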
@@ -2,6 +2,8 @@
 from __future__ import annotations
 
+import collections.abc
 import contextlib
 from collections import defaultdict
@@ -14,7 +16,7 @@ from torchvision.transforms.v2 import functional as F
 __all__ = ["wrap_dataset_for_transforms_v2"]
 
 
-def wrap_dataset_for_transforms_v2(dataset):
+def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
     """[BETA] Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.
 
     .. v2betastatus:: wrap_dataset_for_transforms_v2 function
@@ -36,15 +38,17 @@ def wrap_dataset_for_transforms_v2(dataset):
     * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as list of dicts, the wrapper
       returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
       ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``.
-      The original keys are preserved.
+      The original keys are preserved. If ``target_keys`` is omitted, returns only the values for the ``"boxes"``
+      and ``"labels"``.
     * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
       the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
-      preserved.
+      preserved. If ``target_keys`` is omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
     * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
       coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
-    * :class:`~torchvision.datasets.Kitti`: Instead returning the target as list of dicts the wrapper returns a dict
-      of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
-      in the corresponding ``torchvision.datapoints``. The original keys are preserved.
+    * :class:`~torchvision.datasets.Kitti`: Instead of returning the target as list of dicts, the wrapper returns a
+      dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
+      in the corresponding ``torchvision.datapoints``. The original keys are preserved. If ``target_keys`` is
+      omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
     * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
       :class:`~torchvision.datapoints.Mask` datapoint.
     * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
@@ -61,13 +65,13 @@ def wrap_dataset_for_transforms_v2(dataset):
     Segmentation datasets
-        Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation` return a two-tuple of
+        Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of
         :class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
         segmentation mask into a :class:`~torchvision.datapoints.Mask` (second item).
 
     Video classification datasets
-        Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics` return a three-tuple containing a
+        Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing a
         :class:`torch.Tensor` for the video and audio and a :class:`int` as label. This wrapper wraps the video into a
         :class:`~torchvision.datapoints.Video` while leaving the other items as is.
@@ -78,8 +82,23 @@ def wrap_dataset_for_transforms_v2(dataset):
     Args:
         dataset: the dataset instance to wrap for compatibility with transforms v2.
+        target_keys: Target keys to return in case the target is a dictionary. If ``None`` (default), selected keys are
+            specific to the dataset. If ``"all"``, returns the full target. Can also be a collection of strings for
+            fine-grained access. Currently only supported for :class:`~torchvision.datasets.CocoDetection`,
+            :class:`~torchvision.datasets.VOCDetection`, :class:`~torchvision.datasets.Kitti`, and
+            :class:`~torchvision.datasets.WIDERFace`. See above for details.
     """
-    return VisionDatasetDatapointWrapper(dataset)
+    if not (
+        target_keys is None
+        or target_keys == "all"
+        or (isinstance(target_keys, collections.abc.Collection) and all(isinstance(key, str) for key in target_keys))
+    ):
+        raise ValueError(
+            f"`target_keys` can be None, 'all', or a collection of strings denoting the keys to be returned, "
+            f"but got {target_keys}"
+        )
+
+    return VisionDatasetDatapointWrapper(dataset, target_keys)
 
 
 class WrapperFactories(dict):
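The validation above only accepts ``None``, the string ``"all"``, or a collection of strings; anything else fails fast before a wrapper is even built. A quick sketch of the failure mode, with placeholder dataset paths:

from torchvision import datasets

dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")  # placeholder paths

# An int (or any other non-collection) is rejected immediately.
try:
    datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=42)
except ValueError as error:
    print(error)  # "`target_keys` can be None, 'all', or a collection of strings ..."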
@@ -99,7 +118,7 @@ WRAPPER_FACTORIES = WrapperFactories()
 
 
 class VisionDatasetDatapointWrapper(Dataset):
-    def __init__(self, dataset):
+    def __init__(self, dataset, target_keys):
         dataset_cls = type(dataset)
 
         if not isinstance(dataset, datasets.VisionDataset):
@@ -111,6 +130,16 @@ class VisionDatasetDatapointWrapper(Dataset):
         for cls in dataset_cls.mro():
             if cls in WRAPPER_FACTORIES:
                 wrapper_factory = WRAPPER_FACTORIES[cls]
+                if target_keys is not None and cls not in {
+                    datasets.CocoDetection,
+                    datasets.VOCDetection,
+                    datasets.Kitti,
+                    datasets.WIDERFace,
+                }:
+                    raise ValueError(
+                        f"`target_keys` is currently only supported for `CocoDetection`, `VOCDetection`, `Kitti`, "
+                        f"and `WIDERFace`, but got {cls.__name__}."
+                    )
                 break
             elif cls is datasets.VisionDataset:
                 # TODO: If we have documentation on how to do that, put a link in the error message.
@@ -123,7 +152,7 @@ class VisionDatasetDatapointWrapper(Dataset):
                 raise TypeError(msg)
 
         self._dataset = dataset
-        self._wrapper = wrapper_factory(dataset)
+        self._wrapper = wrapper_factory(dataset, target_keys)
 
         # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
         # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint
@@ -170,7 +199,7 @@ def identity(item):
     return item
 
 
-def identity_wrapper_factory(dataset):
+def identity_wrapper_factory(dataset, target_keys):
     def wrapper(idx, sample):
         return sample
@@ -181,6 +210,20 @@ def pil_image_to_mask(pil_image):
     return datapoints.Mask(pil_image)
 
 
+def parse_target_keys(target_keys, *, available, default):
+    if target_keys is None:
+        target_keys = default
+
+    if target_keys == "all":
+        target_keys = available
+    else:
+        target_keys = set(target_keys)
+        extra = target_keys - available
+        if extra:
+            raise ValueError(f"Target keys {sorted(extra)} are not available")
+
+    return target_keys
+
+
 def list_of_dicts_to_dict_of_lists(list_of_dicts):
     dict_of_lists = defaultdict(list)
     for dct in list_of_dicts:
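A short sketch of the semantics of the new ``parse_target_keys`` helper and of the ``list_of_dicts_to_dict_of_lists`` transposition it is combined with. This is a standalone illustration under the stated assumptions, not a copy of the helpers above:

available = {"bbox", "segmentation", "boxes", "labels"}

# parse_target_keys(None, available=available, default={"boxes", "labels"}) -> {"boxes", "labels"}
# parse_target_keys("all", available=available, default={"boxes", "labels"}) -> the full available set
# parse_target_keys({"scores"}, ...) -> ValueError("Target keys ['scores'] are not available")

# list_of_dicts_to_dict_of_lists transposes COCO-style per-instance dicts into one dict of lists:
list_of_dicts = [{"bbox": [0, 0, 1, 1], "category_id": 3}, {"bbox": [1, 1, 2, 2], "category_id": 7}]
dict_of_lists = {}
for dct in list_of_dicts:
    for key, value in dct.items():
        dict_of_lists.setdefault(key, []).append(value)
print(dict_of_lists)  # {'bbox': [[0, 0, 1, 1], [1, 1, 2, 2]], 'category_id': [3, 7]}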
@@ -203,8 +246,8 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
     return wrapped_target
 
 
-def classification_wrapper_factory(dataset):
-    return identity_wrapper_factory(dataset)
+def classification_wrapper_factory(dataset, target_keys):
+    return identity_wrapper_factory(dataset, target_keys)
 
 
 for dataset_cls in [
@@ -221,7 +264,7 @@ for dataset_cls in [
     WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)
 
 
-def segmentation_wrapper_factory(dataset):
+def segmentation_wrapper_factory(dataset, target_keys):
     def wrapper(idx, sample):
         image, mask = sample
         return image, pil_image_to_mask(mask)
@@ -235,7 +278,7 @@ for dataset_cls in [
     WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory)
 
 
-def video_classification_wrapper_factory(dataset):
+def video_classification_wrapper_factory(dataset, target_keys):
     if dataset.video_clips.output_format == "THWC":
         raise RuntimeError(
             f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, "
@@ -261,15 +304,33 @@ for dataset_cls in [
 
 @WRAPPER_FACTORIES.register(datasets.Caltech101)
-def caltech101_wrapper_factory(dataset):
+def caltech101_wrapper_factory(dataset, target_keys):
     if "annotation" in dataset.target_type:
         raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")
 
-    return classification_wrapper_factory(dataset)
+    return classification_wrapper_factory(dataset, target_keys)
 
 
 @WRAPPER_FACTORIES.register(datasets.CocoDetection)
-def coco_dectection_wrapper_factory(dataset):
+def coco_dectection_wrapper_factory(dataset, target_keys):
+    target_keys = parse_target_keys(
+        target_keys,
+        available={
+            # native
+            "segmentation",
+            "area",
+            "iscrowd",
+            "image_id",
+            "bbox",
+            "category_id",
+            # added by the wrapper
+            "boxes",
+            "masks",
+            "labels",
+        },
+        default={"boxes", "labels"},
+    )
+
     def segmentation_to_mask(segmentation, *, spatial_size):
         from pycocotools import mask
@@ -288,12 +349,16 @@ def coco_dectection_wrapper_factory(dataset):
         if not target:
             return image, dict(image_id=image_id)
 
+        spatial_size = tuple(F.get_spatial_size(image))
+
         batched_target = list_of_dicts_to_dict_of_lists(target)
+        target = {}
 
-        batched_target["image_id"] = image_id
+        if "image_id" in target_keys:
+            target["image_id"] = image_id
 
-        spatial_size = tuple(F.get_spatial_size(image))
-        batched_target["boxes"] = F.convert_format_bounding_box(
+        if "boxes" in target_keys:
+            target["boxes"] = F.convert_format_bounding_box(
                 datapoints.BoundingBox(
                     batched_target["bbox"],
                     format=datapoints.BoundingBoxFormat.XYWH,
@@ -301,7 +366,9 @@ def coco_dectection_wrapper_factory(dataset):
                 ),
                 new_format=datapoints.BoundingBoxFormat.XYXY,
             )
-        batched_target["masks"] = datapoints.Mask(
+
+        if "masks" in target_keys:
+            target["masks"] = datapoints.Mask(
                 torch.stack(
                     [
                         segmentation_to_mask(segmentation, spatial_size=spatial_size)
@@ -309,9 +376,14 @@ def coco_dectection_wrapper_factory(dataset):
                     ]
                 ),
             )
-        batched_target["labels"] = torch.tensor(batched_target["category_id"])
 
-        return image, batched_target
+        if "labels" in target_keys:
+            target["labels"] = torch.tensor(batched_target["category_id"])
+
+        for target_key in target_keys - {"image_id", "boxes", "masks", "labels"}:
+            target[target_key] = batched_target[target_key]
+
+        return image, target
 
     return wrapper
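Besides the synthesized ``"boxes"``/``"masks"``/``"labels"`` entries, any requested native COCO key is now copied over verbatim from the batched annotations. An illustrative sketch for a sample that has annotations; paths are placeholders and the printed keys follow from the wrapper logic above:

from torchvision import datasets

dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")  # placeholder paths

# Request the wrapper-added keys plus two native per-instance fields.
wrapped = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "area", "iscrowd"})
_, target = wrapped[0]
print(sorted(target.keys()))  # expected: ['area', 'boxes', 'iscrowd', 'labels']
# "area" and "iscrowd" are plain lists copied from the annotations; "boxes" and
# "labels" are the datapoints.BoundingBox / torch.Tensor built by the wrapper.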
@@ -346,12 +418,28 @@ VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC
 @WRAPPER_FACTORIES.register(datasets.VOCDetection)
-def voc_detection_wrapper_factory(dataset):
+def voc_detection_wrapper_factory(dataset, target_keys):
+    target_keys = parse_target_keys(
+        target_keys,
+        available={
+            # native
+            "annotation",
+            # added by the wrapper
+            "boxes",
+            "labels",
+        },
+        default={"boxes", "labels"},
+    )
+
     def wrapper(idx, sample):
         image, target = sample
 
         batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])
 
-        target["boxes"] = datapoints.BoundingBox(
+        if "annotation" not in target_keys:
+            target = {}
+
+        if "boxes" in target_keys:
+            target["boxes"] = datapoints.BoundingBox(
                 [
                     [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
@@ -360,6 +448,8 @@ def voc_detection_wrapper_factory(dataset):
                 format=datapoints.BoundingBoxFormat.XYXY,
                 spatial_size=(image.height, image.width),
             )
-        target["labels"] = torch.tensor(
+
+        if "labels" in target_keys:
+            target["labels"] = torch.tensor(
                 [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]]
             )
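For VOC, the only native key is the parsed ``"annotation"`` dict; if it is not requested, the wrapper starts from an empty target and only adds what was asked for. A hedged sketch, assuming the VOC files are present locally under a placeholder root:

from torchvision import datasets

dataset = datasets.VOCDetection("path/to/VOCdevkit", year="2012", image_set="val")  # placeholders

# Default: only the synthesized keys.
_, target = datasets.wrap_dataset_for_transforms_v2(dataset)[0]
print(sorted(target.keys()))  # ['boxes', 'labels']

# Keep the original parsed XML alongside the synthesized keys.
_, target = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys="all")[0]
print(sorted(target.keys()))  # ['annotation', 'boxes', 'labels']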
@@ -370,15 +460,15 @@ def voc_detection_wrapper_factory(dataset):
 
 @WRAPPER_FACTORIES.register(datasets.SBDataset)
-def sbd_wrapper(dataset):
+def sbd_wrapper(dataset, target_keys):
     if dataset.mode == "boundaries":
         raise_not_supported("SBDataset with mode='boundaries'")
 
-    return segmentation_wrapper_factory(dataset)
+    return segmentation_wrapper_factory(dataset, target_keys)
 
 
 @WRAPPER_FACTORIES.register(datasets.CelebA)
-def celeba_wrapper_factory(dataset):
+def celeba_wrapper_factory(dataset, target_keys):
     if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
         raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")
@@ -410,17 +500,47 @@ KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES)))
 
 @WRAPPER_FACTORIES.register(datasets.Kitti)
-def kitti_wrapper_factory(dataset):
+def kitti_wrapper_factory(dataset, target_keys):
+    target_keys = parse_target_keys(
+        target_keys,
+        available={
+            # native
+            "type",
+            "truncated",
+            "occluded",
+            "alpha",
+            "bbox",
+            "dimensions",
+            "location",
+            "rotation_y",
+            # added by the wrapper
+            "boxes",
+            "labels",
+        },
+        default={"boxes", "labels"},
+    )
+
     def wrapper(idx, sample):
         image, target = sample
 
-        if target is not None:
-            target = list_of_dicts_to_dict_of_lists(target)
-
-            target["boxes"] = datapoints.BoundingBox(
-                target["bbox"], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(image.height, image.width)
-            )
-            target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in target["type"]])
+        if target is None:
+            return image, target
+
+        batched_target = list_of_dicts_to_dict_of_lists(target)
+
+        target = {}
+
+        if "boxes" in target_keys:
+            target["boxes"] = datapoints.BoundingBox(
+                batched_target["bbox"],
+                format=datapoints.BoundingBoxFormat.XYXY,
+                spatial_size=(image.height, image.width),
+            )
+
+        if "labels" in target_keys:
+            target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in batched_target["type"]])
+
+        for target_key in target_keys - {"boxes", "labels"}:
+            target[target_key] = batched_target[target_key]
 
         return image, target
@@ -428,7 +548,7 @@ def kitti_wrapper_factory(dataset):
 
 @WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
-def oxford_iiit_pet_wrapper_factor(dataset):
+def oxford_iiit_pet_wrapper_factor(dataset, target_keys):
     def wrapper(idx, sample):
         image, target = sample
@@ -447,7 +567,7 @@ def oxford_iiit_pet_wrapper_factor(dataset):
 
 @WRAPPER_FACTORIES.register(datasets.Cityscapes)
-def cityscapes_wrapper_factory(dataset):
+def cityscapes_wrapper_factory(dataset, target_keys):
     if any(target_type in dataset.target_type for target_type in ["polygon", "color"]):
         raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`")
@@ -482,11 +602,30 @@ def cityscapes_wrapper_factory(dataset):
 
 @WRAPPER_FACTORIES.register(datasets.WIDERFace)
-def widerface_wrapper(dataset):
+def widerface_wrapper(dataset, target_keys):
+    target_keys = parse_target_keys(
+        target_keys,
+        available={
+            "bbox",
+            "blur",
+            "expression",
+            "illumination",
+            "occlusion",
+            "pose",
+            "invalid",
+        },
+        default="all",
+    )
+
     def wrapper(idx, sample):
         image, target = sample
 
-        if target is not None:
+        if target is None:
+            return image, target
+
+        target = {key: target[key] for key in target_keys}
+
+        if "bbox" in target_keys:
             target["bbox"] = F.convert_format_bounding_box(
                 datapoints.BoundingBox(
                     target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
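Note that unlike the detection datasets above, the WIDERFace wrapper keeps ``default="all"``, so without ``target_keys`` the full native target is preserved and only ``"bbox"`` is wrapped into a :class:`~torchvision.datapoints.BoundingBox`. A small sketch for a split that carries annotations; the root is a placeholder:

from torchvision import datasets

dataset = datasets.WIDERFace("path/to/widerface", split="val")  # placeholder root

_, target = datasets.wrap_dataset_for_transforms_v2(dataset)[0]
# With the "all" default every native key survives:
print(sorted(target.keys()))
# ['bbox', 'blur', 'expression', 'illumination', 'invalid', 'occlusion', 'pose']

# Restrict to the boxes only:
_, target = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys={"bbox"})[0]
print(sorted(target.keys()))  # ['bbox']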