Unverified Commit 10e658cd authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix caltech and imagenet prototype datasets (#5032)

parent d98cccb0
......@@ -2,9 +2,10 @@ import io
import builtin_dataset_mocks
import pytest
import torch
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype import datasets, features
from torchvision.prototype import datasets, transforms
from torchvision.prototype.datasets._api import DEFAULT_DECODER
from torchvision.prototype.utils._internal import sequence_to_str
......@@ -88,10 +89,17 @@ class TestCommon:
)
@dataset_parametrization(decoder=DEFAULT_DECODER)
def test_at_least_one_feature(self, dataset, mock_info):
sample = next(iter(dataset))
if not any(isinstance(value, features.Feature) for value in sample.values()):
raise AssertionError("The sample contained no feature.")
def test_no_vanilla_tensors(self, dataset, mock_info):
vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
if vanilla_tensors:
raise AssertionError(
f"The values of key(s) "
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
)
@dataset_parametrization()
def test_transformable(self, dataset, mock_info):
next(iter(dataset.map(transforms.Identity())))
@dataset_parametrization()
def test_traversable(self, dataset, mock_info):
......
......@@ -21,7 +21,7 @@ from torchvision.prototype.datasets.utils import (
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat
from torchvision.prototype.features import Label, BoundingBox
from torchvision.prototype.features import Label, BoundingBox, Feature
class Caltech101(Dataset):
......@@ -98,7 +98,7 @@ class Caltech101(Dataset):
ann = read_mat(ann_buffer)
bbox = BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy")
contour = torch.tensor(ann["obj_contour"].T)
contour = Feature(ann["obj_contour"].T)
return dict(
category=category,
......
......@@ -21,7 +21,7 @@ from torchvision.prototype.datasets.utils._internal import (
getitem,
read_mat,
)
from torchvision.prototype.features import Label, DEFAULT
from torchvision.prototype.features import Label
from torchvision.prototype.utils._internal import FrozenMapping
......@@ -30,18 +30,6 @@ class ImageNetResource(ManualDownloadResource):
super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)
class ImageNetLabel(Label):
wnid: Optional[str]
@classmethod
def _parse_meta_data(
cls,
category: Optional[str] = DEFAULT, # type: ignore[assignment]
wnid: Optional[str] = DEFAULT, # type: ignore[assignment]
) -> Dict[str, Tuple[Any, Any]]:
return dict(category=(category, None), wnid=(wnid, None))
class ImageNet(Dataset):
def _make_info(self) -> DatasetInfo:
name = "imagenet"
......@@ -97,12 +85,12 @@ class ImageNet(Dataset):
_TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[ImageNetLabel, Tuple[str, io.IOBase]]:
def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]:
path = pathlib.Path(data[0])
wnid = self._TRAIN_IMAGE_NAME_PATTERN.match(path.name).group("wnid") # type: ignore[union-attr]
category = self.wnid_to_category[wnid]
label = ImageNetLabel(self.categories.index(category), category=category, wnid=wnid)
return label, data
label_data = (Label(self.categories.index(category)), category, wnid)
return label_data, data
_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
......@@ -112,28 +100,32 @@ class ImageNet(Dataset):
def _collate_val_data(
self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
) -> Tuple[ImageNetLabel, Tuple[str, io.IOBase]]:
) -> Tuple[Tuple[Label, str, str], Tuple[str, io.IOBase]]:
label_data, image_data = data
_, label = label_data
category = self.categories[label]
wnid = self.category_to_wnid[category]
return ImageNetLabel(label, category=category, wnid=wnid), image_data
return (Label(label), category, wnid), image_data
def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[None, Tuple[str, io.IOBase]]:
return None, data
def _collate_and_decode_sample(
self,
data: Tuple[Optional[ImageNetLabel], Tuple[str, io.IOBase]],
data: Tuple[Optional[Tuple[Label, str, str]], Tuple[str, io.IOBase]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
label, (path, buffer) = data
return dict(
label_data, (path, buffer) = data
sample = dict(
path=path,
image=decoder(buffer) if decoder else buffer,
label=label,
)
if label_data:
sample.update(dict(zip(("label", "category", "wnid"), label_data)))
return sample
def _make_datapipe(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment