Unverified Commit 1c630960 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

use enums in prototype datasets for demux (#5189)

* use enums in prototype datasets for demux

* use enum for category generation

* revert enum usage for single use constants
parent 68f511eb
import enum
import io import io
import pathlib import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
...@@ -30,6 +31,12 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -30,6 +31,12 @@ from torchvision.prototype.datasets.utils._internal import (
from torchvision.prototype.features import Label from torchvision.prototype.features import Label
class DTDDemux(enum.IntEnum):
SPLIT = 0
JOINT_CATEGORIES = 1
IMAGES = 2
class DTD(Dataset): class DTD(Dataset):
def _make_info(self) -> DatasetInfo: def _make_info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
...@@ -54,11 +61,11 @@ class DTD(Dataset): ...@@ -54,11 +61,11 @@ class DTD(Dataset):
path = pathlib.Path(data[0]) path = pathlib.Path(data[0])
if path.parent.name == "labels": if path.parent.name == "labels":
if path.name == "labels_joint_anno.txt": if path.name == "labels_joint_anno.txt":
return 1 return DTDDemux.JOINT_CATEGORIES
return 0 return DTDDemux.SPLIT
elif path.parents[1].name == "images": elif path.parents[1].name == "images":
return 2 return DTDDemux.IMAGES
else: else:
return None return None
...@@ -122,7 +129,7 @@ class DTD(Dataset): ...@@ -122,7 +129,7 @@ class DTD(Dataset):
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
def _filter_images(self, data: Tuple[str, Any]) -> bool: def _filter_images(self, data: Tuple[str, Any]) -> bool:
return self._classify_archive(data) == 2 return self._classify_archive(data) == DTDDemux.IMAGES
def _generate_categories(self, root: pathlib.Path) -> List[str]: def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name) dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
......
import enum
import functools import functools
import io import io
import pathlib import pathlib
...@@ -24,6 +25,11 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -24,6 +25,11 @@ from torchvision.prototype.datasets.utils._internal import (
from torchvision.prototype.features import Label from torchvision.prototype.features import Label
class OxfordIITPetDemux(enum.IntEnum):
SPLIT_AND_CLASSIFICATION = 0
SEGMENTATIONS = 1
class OxfordIITPet(Dataset): class OxfordIITPet(Dataset):
def _make_info(self) -> DatasetInfo: def _make_info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
...@@ -51,8 +57,8 @@ class OxfordIITPet(Dataset): ...@@ -51,8 +57,8 @@ class OxfordIITPet(Dataset):
def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]: def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]:
return { return {
"annotations": 0, "annotations": OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION,
"trimaps": 1, "trimaps": OxfordIITPetDemux.SEGMENTATIONS,
}.get(pathlib.Path(data[0]).parent.name) }.get(pathlib.Path(data[0]).parent.name)
def _filter_images(self, data: Tuple[str, Any]) -> bool: def _filter_images(self, data: Tuple[str, Any]) -> bool:
...@@ -135,7 +141,7 @@ class OxfordIITPet(Dataset): ...@@ -135,7 +141,7 @@ class OxfordIITPet(Dataset):
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder)) return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool: def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool:
return self._classify_anns(data) == 0 return self._classify_anns(data) == OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION
def _generate_categories(self, root: pathlib.Path) -> List[str]: def _generate_categories(self, root: pathlib.Path) -> List[str]:
config = self.default_config config = self.default_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment