Unverified commit 7c9878a4 authored by Philip Meier, committed by GitHub

remove datapoints compatibility for prototype datasets (#7154)

parent a9d25721
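
In short: the prototype datasets no longer wrap plain annotation values in the generic `Datapoint` base class but return regular tensors instead, and the test suite now only flags simple tensors that are not accompanied by an (encoded) image or video. A minimal sketch of the pattern applied throughout the diff below (the `ann` dict is a made-up stand-in for a loaded annotation):

import torch

# Hypothetical annotation values standing in for what a dataset loads.
ann = {"area": [12.0, 34.5], "iscrowd": [0, 1]}

# Before (removed): areas = Datapoint(ann["area"])
# After: plain tensors. torch.as_tensor avoids a copy when the input is
# already a tensor or a NumPy array.
areas = torch.as_tensor(ann["area"])
crowds = torch.as_tensor(ann["iscrowd"], dtype=torch.bool)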
test/test_prototype_datasets_builtin.py
@@ -21,6 +21,7 @@ from torchdata.datapipes.iter import ShardingFilter, Shuffler
 from torchdata.datapipes.utils import StreamWrapper
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import datapoints, datasets, transforms
+from torchvision.prototype.datasets.utils import EncodedImage
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
@@ -136,18 +137,21 @@ class TestCommon:
             raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_simple_tensors(self, dataset_mock, config):
+    def test_no_unaccompanied_simple_tensors(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
+        sample = next_consume(iter(dataset))
+
         simple_tensors = {
-            key
-            for key, value in next_consume(iter(dataset)).items()
-            if torchvision.prototype.transforms.utils.is_simple_tensor(value)
+            key for key, value in sample.items() if torchvision.prototype.transforms.utils.is_simple_tensor(value)
         }

-        if simple_tensors:
+        if simple_tensors and not any(
+            isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
+        ):
             raise AssertionError(
                 f"The values of key(s) "
-                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors."
+                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors, "
+                f"but didn't find any (encoded) image or video."
             )

     @parametrize_dataset_mocks(DATASET_MOCKS)
...
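For context, the reworked test accepts plain tensors in a sample as long as an image or video accompanies them; only "unaccompanied" simple tensors fail. A self-contained sketch of that check, with hypothetical stand-ins for `is_simple_tensor` and `datapoints.Image`:

import torch

def is_simple_tensor(value):
    # Stand-in for torchvision.prototype.transforms.utils.is_simple_tensor:
    # a plain torch.Tensor that is not a tensor subclass.
    return type(value) is torch.Tensor

class FakeImage(torch.Tensor):  # stands in for datapoints.Image
    pass

# Hypothetical sample: the plain "contour" tensor is fine because an image
# accompanies it.
sample = {"image": torch.rand(3, 8, 8).as_subclass(FakeImage), "contour": torch.rand(5, 2)}

simple_tensors = {key for key, value in sample.items() if is_simple_tensor(value)}
unaccompanied = simple_tensors and not any(isinstance(v, FakeImage) for v in sample.values())
assert not unaccompanied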
torchvision/prototype/datapoints/_datapoint.py
@@ -29,26 +29,9 @@ class Datapoint(torch.Tensor):
         requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
         return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

-    # FIXME: this is just here for BC with the prototype datasets. Some datasets use the Datapoint directly to have a
-    # a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be
-    # interpreted as images. We should decide if we want a public no-op datapoint like `GenericDatapoint` or make this
-    # one public again.
-    def __new__(
-        cls,
-        data: Any,
-        dtype: Optional[torch.dtype] = None,
-        device: Optional[Union[torch.device, str, int]] = None,
-        requires_grad: Optional[bool] = None,
-    ) -> Datapoint:
-        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
-        return tensor.as_subclass(Datapoint)
-
     @classmethod
     def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
-        # FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved,
-        # this method should be made abstract
-        # raise NotImplementedError
-        return tensor.as_subclass(cls)
+        raise NotImplementedError

     _NO_WRAPPING_EXCEPTIONS = {
         torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
...
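With the backward-compatibility constructor gone, the generic `Datapoint` can no longer be instantiated as a no-op wrapper, and `wrap_like` on the base class now raises; only concrete subclasses that override it stay usable. A rough sketch of the resulting contract (simplified, not the actual torchvision classes):

import torch

class Datapoint(torch.Tensor):
    # Base class: wrapping is now the subclasses' job.
    @classmethod
    def wrap_like(cls, other, tensor):
        raise NotImplementedError

class Image(Datapoint):
    # Concrete subclasses override wrap_like to re-wrap plain tensor results.
    @classmethod
    def wrap_like(cls, other, tensor):
        return tensor.as_subclass(cls)

img = torch.rand(3, 4, 4).as_subclass(Image)
out = Image.wrap_like(img, img + 1)  # re-wrapped as Image
assert isinstance(out, Image)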
torchvision/prototype/datasets/_builtin/caltech.py
@@ -3,9 +3,10 @@ import re
 from typing import Any, BinaryIO, Dict, List, Tuple, Union

 import numpy as np
+import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
 from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,
@@ -115,7 +116,7 @@ class Caltech101(Dataset):
                 format="xyxy",
                 spatial_size=image.spatial_size,
             ),
-            contour=Datapoint(ann["obj_contour"].T),
+            contour=torch.as_tensor(ann["obj_contour"].T),
         )

     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
...
torchvision/prototype/datasets/_builtin/celeba.py
@@ -2,9 +2,9 @@ import csv
 import pathlib
 from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple, Union

+import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
 from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -149,7 +149,7 @@ class CelebA(Dataset):
                 spatial_size=image.spatial_size,
             ),
             landmarks={
-                landmark: Datapoint((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
+                landmark: torch.tensor((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
                 for landmark in {key[:-2] for key in landmarks.keys()}
             },
         )
...
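Most call sites in this commit switch to `torch.as_tensor`, which reuses the input's memory when possible (e.g. for NumPy arrays), whereas the CelebA landmarks use `torch.tensor`, which always copies; for a freshly built tuple of Python ints a copy happens either way. A small illustration of the difference, using a made-up array:

import numpy as np
import torch

arr = np.zeros((2, 3))

shared = torch.as_tensor(arr)  # shares memory with the NumPy array
copied = torch.tensor(arr)     # always allocates a new copy

arr[0, 0] = 1.0
assert shared[0, 0] == 1.0  # change is visible through the shared tensor
assert copied[0, 0] == 0.0  # the copy is unaffected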
torchvision/prototype/datasets/_builtin/coco.py
@@ -15,7 +15,6 @@ from torchdata.datapipes.iter import (
     UnBatcher,
 )
 from torchvision.prototype.datapoints import BoundingBox, Label, Mask
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -124,8 +123,8 @@ class Coco(Dataset):
                     ]
                 )
             ),
-            areas=Datapoint([ann["area"] for ann in anns]),
-            crowds=Datapoint([ann["iscrowd"] for ann in anns], dtype=torch.bool),
+            areas=torch.as_tensor([ann["area"] for ann in anns]),
+            crowds=torch.as_tensor([ann["iscrowd"] for ann in anns], dtype=torch.bool),
             bounding_boxes=BoundingBox(
                 [ann["bbox"] for ann in anns],
                 format="xywh",
...
torchvision/prototype/datasets/_builtin/cub200.py
@@ -3,6 +3,7 @@ import functools
 import pathlib
 from typing import Any, BinaryIO, Callable, Dict, List, Optional, Tuple, Union

+import torch
 from torchdata.datapipes.iter import (
     CSVDictParser,
     CSVParser,
@@ -15,7 +16,6 @@ from torchdata.datapipes.iter import (
 )
 from torchdata.datapipes.map import IterToMapConverter
 from torchvision.prototype.datapoints import BoundingBox, Label
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -162,7 +162,7 @@ class CUB200(Dataset):
                 format="xyxy",
                 spatial_size=spatial_size,
             ),
-            segmentation=Datapoint(content["seg"]),
+            segmentation=torch.as_tensor(content["seg"]),
         )

     def _prepare_sample(
...
torchvision/prototype/datasets/_builtin/sbd.py
@@ -3,8 +3,8 @@ import re
 from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union

 import numpy as np
+import torch
 from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
-from torchvision.prototype.datapoints._datapoint import Datapoint
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,
@@ -92,8 +92,10 @@ class SBD(Dataset):
             image=EncodedImage.from_file(image_buffer),
             ann_path=ann_path,
             # the boundaries are stored in sparse CSC format, which is not supported by PyTorch
-            boundaries=Datapoint(np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])),
-            segmentation=Datapoint(anns["Segmentation"].item()),
+            boundaries=torch.as_tensor(
+                np.stack([raw_boundary.toarray() for raw_boundary in anns["Boundaries"].item()])
+            ),
+            segmentation=torch.as_tensor(anns["Segmentation"].item()),
         )

     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
...
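The comment kept in the SBD hunk explains the `np.stack(...toarray())` dance: the boundary maps ship as SciPy sparse CSC matrices, which `torch.as_tensor` cannot consume directly, so each one is densified first. A standalone sketch with made-up boundary data:

import numpy as np
import torch
from scipy.sparse import csc_matrix

# Made-up stand-ins for the per-category boundary maps SBD stores as sparse CSC.
raw_boundaries = [csc_matrix(np.eye(4, dtype=np.uint8)) for _ in range(3)]

# Densify each sparse matrix, stack into one array, then convert without a copy.
boundaries = torch.as_tensor(np.stack([b.toarray() for b in raw_boundaries]))
assert boundaries.shape == (3, 4, 4)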