Unverified Commit b78d98bb authored by Philip Meier, committed by GitHub

add example for v2 wrapping for custom datasets (#7514)

parent fc377d04
@@ -20,6 +20,7 @@ import torchvision
torchvision.disable_beta_transforms_warning()
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F
########################################################################################################################
@@ -93,6 +94,68 @@ print(bounding_box)
# built-in datasets. Meaning, if your custom dataset subclasses from a built-in one and the output type is the same, you
# also don't have to wrap manually.
#
# If you have a custom dataset, for example the ``PennFudanDataset`` from
# `this tutorial <https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html>`_, you have two options:
#
# 1. Perform the wrapping inside ``__getitem__``:
class PennFudanDataset(torch.utils.data.Dataset):
...
def __getitem__(self, item):
...
target["boxes"] = datapoints.BoundingBox(
boxes,
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=F.get_spatial_size(img),
)
target["labels"] = labels
target["masks"] = datapoints.Mask(masks)
...
if self.transforms is not None:
img, target = self.transforms(img, target)
...
########################################################################################################################
# 2. Perform the wrapping inside a custom transformation at the beginning of your pipeline:
class WrapPennFudanDataset:
def __call__(self, img, target):
target["boxes"] = datapoints.BoundingBox(
target["boxes"],
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=F.get_spatial_size(img),
)
target["masks"] = datapoints.Mask(target["masks"])
return img, target
...
def get_transform(train):
transforms = []
transforms.append(WrapPennFudanDataset())
transforms.append(T.PILToTensor())
...
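########################################################################################################################
# Either way you choose, the wrapped values flow through the v2 pipeline unchanged. As a hedged sketch of how the
# pieces fit together, assuming the full ``PennFudanDataset`` and ``get_transform`` definitions from the linked
# tutorial (the data path below is illustrative):
#
# .. code-block:: python
#
#     dataset = PennFudanDataset("PennFudanPed", get_transform(train=True))
#     img, target = dataset[0]
#     # target["boxes"] is a datapoints.BoundingBox and target["masks"] is a
#     # datapoints.Mask, so transforms.v2 dispatches on them correctly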
########################################################################################################################
# .. note::
#
# If both :class:`~torchvision.datapoints.BoundingBox`'es and :class:`~torchvision.datapoints.Mask`'s are included in
# the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping, or
# at least not wrapping, the obsolete parts can lead to a significant performance boost.
#
# For example, if you are using the ``PennFudanDataset`` for object detection, not wrapping the masks avoids
# transforming them over and over again in the pipeline just to ultimately ignore them. In general, it would be
# even better not to load the masks at all, but this is not possible in this example, since the bounding boxes are
# generated from the masks.
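#
# As a hedged sketch of that optimization (the class name is illustrative, not part of the linked tutorial), the
# wrapping transform from option 2 could simply drop the masks once they are no longer needed:
#
# .. code-block:: python
#
#     class WrapPennFudanDatasetDetectionOnly:
#         def __call__(self, img, target):
#             target["boxes"] = datapoints.BoundingBox(
#                 target["boxes"],
#                 format=datapoints.BoundingBoxFormat.XYXY,
#                 spatial_size=F.get_spatial_size(img),
#             )
#             # the masks were only needed to generate the boxes; dropping them
#             # here keeps the v2 pipeline from transforming them on every sample
#             del target["masks"]
#             return img, target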
#
# How do the datapoints behave inside a computation?
# --------------------------------------------------
#
@@ -101,6 +164,7 @@ print(bounding_box)
# Since, for most operations involving datapoints, it cannot be safely inferred whether the result should retain the
# datapoint type, we choose to return a plain tensor instead of a datapoint (this might change, see note below):
assert isinstance(image, datapoints.Image)
new_image = image + 0
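# the result is a plain tensor rather than a datapoint; a quick check of the behavior just described:
assert isinstance(new_image, torch.Tensor) and not isinstance(new_image, datapoints.Image)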
...
@@ -124,7 +124,9 @@ class VisionDatasetDatapointWrapper(Dataset):
        if not isinstance(dataset, datasets.VisionDataset):
            raise TypeError(
                f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
                f"but got a '{dataset_cls.__name__}' instead.\n"
                f"For an example of how to perform the wrapping for custom datasets, see\n\n"
                "https://pytorch.org/vision/main/auto_examples/plot_datapoints.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
            )
        for cls in dataset_cls.mro():
...