Unverified commit 7c9878a4 authored by Philip Meier, committed by GitHub

remove datapoints compatibility for prototype datasets (#7154)

parent a9d25721
@@ -21,6 +21,7 @@ from torchdata.datapipes.iter import ShardingFilter, Shuffler
 from torchdata.datapipes.utils import StreamWrapper
 from torchvision._utils import sequence_to_str
-from torchvision.prototype import datasets, transforms
+from torchvision.prototype import datapoints, datasets, transforms
+from torchvision.prototype.datasets.utils import EncodedImage
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
@@ -136,18 +137,21 @@ class TestCommon:
         raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_simple_tensors(self, dataset_mock, config):
+    def test_no_unaccompanied_simple_tensors(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
+        sample = next_consume(iter(dataset))

         simple_tensors = {
-            key
-            for key, value in next_consume(iter(dataset)).items()
-            if torchvision.prototype.transforms.utils.is_simple_tensor(value)
+            key for key, value in sample.items() if torchvision.prototype.transforms.utils.is_simple_tensor(value)
         }

-        if simple_tensors:
+        if simple_tensors and not any(
+            isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
+        ):
             raise AssertionError(
                 f"The values of key(s) "
-                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors."
+                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors, "
+                f"but didn't find any (encoded) image or video."
             )

     @parametrize_dataset_mocks(DATASET_MOCKS)
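The relaxed check only fails when a sample contains plain tensors *and* no image or video datapoint to accompany them. A minimal sketch of that invariant, using a hypothetical two-key sample (`is_simple_tensor` and the datapoint classes are the ones referenced in the hunk above):

```python
# Sketch of the invariant the renamed test enforces, on a hypothetical sample.
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms.utils import is_simple_tensor

sample = {
    "image": datapoints.Image(torch.rand(3, 32, 32)),  # wrapped datapoint
    "contour": torch.rand(10, 2),  # plain tensor, allowed when accompanied
}

# keys whose values are plain tensors rather than Datapoint subclasses
simple = {key for key, value in sample.items() if is_simple_tensor(value)}
# at least one image or video datapoint accompanies them, so the test passes
accompanied = any(isinstance(value, (datapoints.Image, datapoints.Video)) for value in sample.values())
assert simple == {"contour"} and accompanied
```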
@@ -29,26 +29,9 @@ class Datapoint(torch.Tensor):
         requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
         return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

-    # FIXME: this is just here for BC with the prototype datasets. Some datasets use the Datapoint directly to have a
-    # a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be
-    # interpreted as images. We should decide if we want a public no-op datapoint like `GenericDatapoint` or make this
-    # one public again.
-    def __new__(
-        cls,
-        data: Any,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[Union[torch.device, str, int]] = None,
-        requires_grad: Optional[bool] = None,
-    ) -> Datapoint:
-        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
-        return tensor.as_subclass(Datapoint)
-
     @classmethod
     def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
-        # FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved,
-        # this method should be made abstract
-        # raise NotImplementedError
-        return tensor.as_subclass(cls)
+        raise NotImplementedError

     _NO_WRAPPING_EXCEPTIONS = {
         torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
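With the base-class fallback removed, `wrap_like` is effectively abstract. A sketch of what a concrete subclass is expected to provide, using a hypothetical `MyDatapoint` and reusing the `_to_tensor` helper and `as_subclass` pattern visible in the removed lines:

```python
from __future__ import annotations

import torch
from torchvision.prototype.datapoints._datapoint import Datapoint


class MyDatapoint(Datapoint):  # hypothetical subclass for illustration
    def __new__(cls, data, dtype=None, device=None, requires_grad=None) -> MyDatapoint:
        # _to_tensor is the base-class helper shown at the top of the hunk
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        return tensor.as_subclass(cls)

    @classmethod
    def wrap_like(cls, other: MyDatapoint, tensor: torch.Tensor) -> MyDatapoint:
        # re-wrap a plain tensor result as the same datapoint type
        return tensor.as_subclass(cls)
```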
@@ -3,9 +3,10 @@ import re
 from typing import Any, BinaryIO, Dict, List, Tuple, Union

 import numpy as np
+import torch
+
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
 from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
@@ -115,7 +116,7 @@ class Caltech101(Dataset):
                 format="xyxy",
                 spatial_size=image.spatial_size,
             ),
-            contour=Datapoint(ann["obj_contour"].T),
+            contour=torch.as_tensor(ann["obj_contour"].T),
         )

     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
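Here and in the datasets below, the former no-op `Datapoint` wrapper is replaced with a plain tensor. A sketch of the conversion, with a hypothetical contour array standing in for the Caltech annotation:

```python
import numpy as np
import torch

obj_contour = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])  # hypothetical 2xN contour
contour = torch.as_tensor(obj_contour.T)  # shares memory with the numpy array when possible
assert type(contour) is torch.Tensor and contour.shape == (3, 2)
```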
@@ -2,9 +2,9 @@ import csv
 import pathlib
 from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union

+import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
 from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -149,7 +149,7 @@ class CelebA(Dataset):
                 spatial_size=image.spatial_size,
             ),
             landmarks={
-                landmark: Datapoint((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
+                landmark: torch.tensor((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
                 for landmark in {key[:-2] for key in landmarks.keys()}
             },
         )
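Note the use of `torch.tensor` rather than `torch.as_tensor` here: the input is a freshly built tuple of Python ints, which requires a copy in any case. A sketch with a hypothetical parsed CSV row:

```python
import torch

landmarks = {"lefteye_x": "69", "lefteye_y": "109"}  # hypothetical CSV fields
point = torch.tensor((int(landmarks["lefteye_x"]), int(landmarks["lefteye_y"])))
assert point.dtype == torch.int64 and point.tolist() == [69, 109]
```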
@@ -15,7 +15,6 @@ from torchdata.datapipes.iter import (
     UnBatcher,
 )
 from torchvision.prototype.datapoints import BoundingBox, Label, Mask
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -124,8 +123,8 @@ class Coco(Dataset):
                     ]
                 )
             ),
-            areas=Datapoint([ann["area"] for ann in anns]),
-            crowds=Datapoint([ann["iscrowd"] for ann in anns], dtype=torch.bool),
+            areas=torch.as_tensor([ann["area"] for ann in anns]),
+            crowds=torch.as_tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool),
             bounding_boxes=BoundingBox(
                 [ann["bbox"] for ann in anns],
                 format="xywh",
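The COCO annotation lists become plain tensors as well; the explicit `dtype` keeps the crowd flags boolean. A sketch with hypothetical annotations:

```python
import torch

anns = [{"area": 100.0, "iscrowd": 0}, {"area": 42.5, "iscrowd": 1}]  # hypothetical
areas = torch.as_tensor([ann["area"] for ann in anns])
crowds = torch.as_tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool)
assert areas.dtype == torch.float32 and crowds.tolist() == [False, True]
```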
@@ -3,6 +3,7 @@ import functools
 import pathlib
 from typing import Any, BinaryIO, Callable, Dict, List, Optional, Tuple, Union

+import torch
 from torchdata.datapipes.iter import (
     CSVDictParser,
     CSVParser,
@@ -15,7 +16,6 @@ from torchdata.datapipes.iter import (
 )
 from torchdata.datapipes.map import IterToMapConverter
 from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -162,7 +162,7 @@ class CUB200(Dataset):
                 format="xyxy",
                 spatial_size=spatial_size,
             ),
-            segmentation=Datapoint(content["seg"]),
+            segmentation=torch.as_tensor(content["seg"]),
        )

     def _prepare_sample(
@@ -3,8 +3,8 @@ import re
 from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union

 import numpy as np
+import torch
 from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -92,8 +92,10 @@ class SBD(Dataset):
             image=EncodedImage.from_file(image_buffer),
             ann_path=ann_path,
             # the boundaries are stored in sparse CSC format, which is not supported by PyTorch
-            boundaries=Datapoint(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])),
-            segmentation=Datapoint(anns["Segmentation"].item()),
+            boundaries=torch.as_tensor(
+                np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])
+            ),
+            segmentation=torch.as_tensor(anns["Segmentation"].item()),
         )

     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
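As the inline comment in the hunk notes, the SBD boundaries are stored as sparse CSC matrices, which PyTorch cannot ingest directly; each one is densified with `.toarray()` before the stack becomes a single plain tensor. A sketch with hypothetical 4x4 boundary maps:

```python
import numpy as np
import torch
from scipy.sparse import csc_matrix

boundaries = [csc_matrix(np.eye(4, dtype=np.uint8)) for _ in range(2)]  # hypothetical
dense = torch.as_tensor(np.stack([b.toarray() for b in boundaries]))
assert dense.shape == (2, 4, 4) and dense.dtype == torch.uint8
```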