Unverified Commit 10e658cd authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix caltech and imagenet prototype datasets (#5032)

parent d98cccb0
...@@ -2,9 +2,10 @@ import io ...@@ -2,9 +2,10 @@ import io
import builtin_dataset_mocks import builtin_dataset_mocks
import pytest import pytest
import torch
from torch.utils.data.graph import traverse from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype import datasets, features from torchvision.prototype import datasets, transforms
from torchvision.prototype.datasets._api import DEFAULT_DECODER from torchvision.prototype.datasets._api import DEFAULT_DECODER
from torchvision.prototype.utils._internal import sequence_to_str from torchvision.prototype.utils._internal import sequence_to_str
...@@ -88,10 +89,17 @@ class TestCommon: ...@@ -88,10 +89,17 @@ class TestCommon:
) )
@dataset_parametrization(decoder=DEFAULT_DECODER)
def test_no_vanilla_tensors(self, dataset, mock_info):
    """Every value in a decoded sample should be wrapped (e.g. as a Feature), never a plain tensor."""
    sample = next(iter(dataset))
    # `type(...) is torch.Tensor` deliberately excludes Tensor subclasses such as features.
    offending = {name for name, value in sample.items() if type(value) is torch.Tensor}
    if offending:
        raise AssertionError(
            f"The values of key(s) "
            f"{sequence_to_str(sorted(offending), separate_last='and ')} contained vanilla tensors."
        )
@dataset_parametrization()
def test_transformable(self, dataset, mock_info):
    """Applying a no-op transform over the datapipe must not raise."""
    transformed = dataset.map(transforms.Identity())
    next(iter(transformed))
@dataset_parametrization() @dataset_parametrization()
def test_traversable(self, dataset, mock_info): def test_traversable(self, dataset, mock_info):
......
...@@ -21,7 +21,7 @@ from torchvision.prototype.datasets.utils import ( ...@@ -21,7 +21,7 @@ from torchvision.prototype.datasets.utils import (
DatasetType, DatasetType,
) )
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat
from torchvision.prototype.features import Label, BoundingBox from torchvision.prototype.features import Label, BoundingBox, Feature
class Caltech101(Dataset): class Caltech101(Dataset):
...@@ -98,7 +98,7 @@ class Caltech101(Dataset): ...@@ -98,7 +98,7 @@ class Caltech101(Dataset):
ann = read_mat(ann_buffer) ann = read_mat(ann_buffer)
bbox = BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy") bbox = BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy")
contour = torch.tensor(ann["obj_contour"].T) contour = Feature(ann["obj_contour"].T)
return dict( return dict(
category=category, category=category,
......
...@@ -21,7 +21,7 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -21,7 +21,7 @@ from torchvision.prototype.datasets.utils._internal import (
getitem, getitem,
read_mat, read_mat,
) )
from torchvision.prototype.features import Label, DEFAULT from torchvision.prototype.features import Label
from torchvision.prototype.utils._internal import FrozenMapping from torchvision.prototype.utils._internal import FrozenMapping
...@@ -30,18 +30,6 @@ class ImageNetResource(ManualDownloadResource): ...@@ -30,18 +30,6 @@ class ImageNetResource(ManualDownloadResource):
super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs) super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)
class ImageNetLabel(Label):
wnid: Optional[str]
@classmethod
def _parse_meta_data(
cls,
category: Optional[str] = DEFAULT, # type: ignore[assignment]
wnid: Optional[str] = DEFAULT, # type: ignore[assignment]
) -> Dict[str, Tuple[Any, Any]]:
return dict(category=(category, None), wnid=(wnid, None))
class ImageNet(Dataset): class ImageNet(Dataset):
def _make_info(self) -> DatasetInfo: def _make_info(self) -> DatasetInfo:
name = "imagenet" name = "imagenet"
...@@ -97,12 +85,12 @@ class ImageNet(Dataset): ...@@ -97,12 +85,12 @@ class ImageNet(Dataset):
_TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG") _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]:
    """Derive (label, category, wnid) for a train image from its file name.

    Train file names look like ``n01440764_123.JPEG``; the wnid prefix is the
    only label information available, so category and index are looked up from it.
    """
    file_name = pathlib.Path(data[0]).name
    wnid = self._TRAIN_IMAGE_NAME_PATTERN.match(file_name).group("wnid")  # type: ignore[union-attr]
    category = self.wnid_to_category[wnid]
    label = Label(self.categories.index(category))
    return (label, category, wnid), data
_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG") _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
...@@ -112,28 +100,32 @@ class ImageNet(Dataset): ...@@ -112,28 +100,32 @@ class ImageNet(Dataset):
def _collate_val_data(
    self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]:
    """Pair a val image with (label, category, wnid) taken from the meta data."""
    (_, label_idx), image_data = data
    category = self.categories[label_idx]
    wnid = self.category_to_wnid[category]
    return (Label(label_idx), category, wnid), image_data
def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[None, Tuple[str, io.IOBase]]: def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[None, Tuple[str, io.IOBase]]:
return None, data return None, data
def _collate_and_decode_sample(
    self,
    data: Tuple[Optional[Tuple[Label, str, str]], Tuple[str, io.IOBase]],
    *,
    decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
    """Build the sample dict from collated (label data, image data).

    ``label_data`` is ``None`` for the test split, in which case the sample
    contains only path and image.
    """
    label_data, (path, buffer) = data

    image = decoder(buffer) if decoder else buffer
    sample: Dict[str, Any] = dict(path=path, image=image)
    if label_data:
        sample.update(zip(("label", "category", "wnid"), label_data))
    return sample
def _make_datapipe( def _make_datapipe(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment