Unverified Commit 1c630960 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

use enums in prototype datasets for demux (#5189)

* use enums in prototype datasets for demux

* use enum for category generation

* revert enum usage for single use constants
parent 68f511eb
import enum
import io
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple
......@@ -30,6 +31,12 @@ from torchvision.prototype.datasets.utils._internal import (
from torchvision.prototype.features import Label
class DTDDemux(enum.IntEnum):
SPLIT = 0
JOINT_CATEGORIES = 1
IMAGES = 2
class DTD(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
......@@ -54,11 +61,11 @@ class DTD(Dataset):
path = pathlib.Path(data[0])
if path.parent.name == "labels":
if path.name == "labels_joint_anno.txt":
return 1
return DTDDemux.JOINT_CATEGORIES
return 0
return DTDDemux.SPLIT
elif path.parents[1].name == "images":
return 2
return DTDDemux.IMAGES
else:
return None
......@@ -122,7 +129,7 @@ class DTD(Dataset):
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
def _filter_images(self, data: Tuple[str, Any]) -> bool:
return self._classify_archive(data) == 2
return self._classify_archive(data) == DTDDemux.IMAGES
def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
......
import enum
import functools
import io
import pathlib
......@@ -24,6 +25,11 @@ from torchvision.prototype.datasets.utils._internal import (
from torchvision.prototype.features import Label
class OxfordIITPetDemux(enum.IntEnum):
SPLIT_AND_CLASSIFICATION = 0
SEGMENTATIONS = 1
class OxfordIITPet(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
......@@ -51,8 +57,8 @@ class OxfordIITPet(Dataset):
def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]:
return {
"annotations": 0,
"trimaps": 1,
"annotations": OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION,
"trimaps": OxfordIITPetDemux.SEGMENTATIONS,
}.get(pathlib.Path(data[0]).parent.name)
def _filter_images(self, data: Tuple[str, Any]) -> bool:
......@@ -135,7 +141,7 @@ class OxfordIITPet(Dataset):
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool:
return self._classify_anns(data) == 0
return self._classify_anns(data) == OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION
def _generate_categories(self, root: pathlib.Path) -> List[str]:
config = self.default_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment