Unverified Commit 31245cb8 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

caltech (#2521)

parent e1c50d9c
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
from typing import Any, Callable, List, Optional, Union, Tuple
from .vision import VisionDataset from .vision import VisionDataset
from .utils import download_and_extract_archive, verify_str_arg from .utils import download_and_extract_archive, verify_str_arg
...@@ -29,8 +30,14 @@ class Caltech101(VisionDataset): ...@@ -29,8 +30,14 @@ class Caltech101(VisionDataset):
downloaded again. downloaded again.
""" """
def __init__(self, root, target_type="category", transform=None, def __init__(
target_transform=None, download=False): self,
root: str,
target_type: Union[List[str], str] = "category",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(Caltech101, self).__init__(os.path.join(root, 'caltech101'), super(Caltech101, self).__init__(os.path.join(root, 'caltech101'),
transform=transform, transform=transform,
target_transform=target_transform) target_transform=target_transform)
...@@ -59,14 +66,14 @@ class Caltech101(VisionDataset): ...@@ -59,14 +66,14 @@ class Caltech101(VisionDataset):
"airplanes": "Airplanes_Side_2"} "airplanes": "Airplanes_Side_2"}
self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories)) self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
self.index = [] self.index: List[int] = []
self.y = [] self.y = []
for (i, c) in enumerate(self.categories): for (i, c) in enumerate(self.categories):
n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c))) n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
self.index.extend(range(1, n + 1)) self.index.extend(range(1, n + 1))
self.y.extend(n * [i]) self.y.extend(n * [i])
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -81,7 +88,7 @@ class Caltech101(VisionDataset): ...@@ -81,7 +88,7 @@ class Caltech101(VisionDataset):
self.categories[self.y[index]], self.categories[self.y[index]],
"image_{:04d}.jpg".format(self.index[index]))) "image_{:04d}.jpg".format(self.index[index])))
target = [] target: Any = []
for t in self.target_type: for t in self.target_type:
if t == "category": if t == "category":
target.append(self.y[index]) target.append(self.y[index])
...@@ -101,14 +108,14 @@ class Caltech101(VisionDataset): ...@@ -101,14 +108,14 @@ class Caltech101(VisionDataset):
return img, target return img, target
def _check_integrity(self): def _check_integrity(self) -> bool:
# can be more robust and check hash of files # can be more robust and check hash of files
return os.path.exists(os.path.join(self.root, "101_ObjectCategories")) return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))
def __len__(self): def __len__(self) -> int:
return len(self.index) return len(self.index)
def download(self): def download(self) -> None:
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
...@@ -124,7 +131,7 @@ class Caltech101(VisionDataset): ...@@ -124,7 +131,7 @@ class Caltech101(VisionDataset):
filename="101_Annotations.tar", filename="101_Annotations.tar",
md5="6f83eeb1f24d99cab4eb377263132c91") md5="6f83eeb1f24d99cab4eb377263132c91")
def extra_repr(self): def extra_repr(self) -> str:
return "Target type: {target_type}".format(**self.__dict__) return "Target type: {target_type}".format(**self.__dict__)
...@@ -143,7 +150,13 @@ class Caltech256(VisionDataset): ...@@ -143,7 +150,13 @@ class Caltech256(VisionDataset):
downloaded again. downloaded again.
""" """
def __init__(self, root, transform=None, target_transform=None, download=False): def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(Caltech256, self).__init__(os.path.join(root, 'caltech256'), super(Caltech256, self).__init__(os.path.join(root, 'caltech256'),
transform=transform, transform=transform,
target_transform=target_transform) target_transform=target_transform)
...@@ -157,14 +170,14 @@ class Caltech256(VisionDataset): ...@@ -157,14 +170,14 @@ class Caltech256(VisionDataset):
' You can use download=True to download it') ' You can use download=True to download it')
self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories"))) self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
self.index = [] self.index: List[int] = []
self.y = [] self.y = []
for (i, c) in enumerate(self.categories): for (i, c) in enumerate(self.categories):
n = len(os.listdir(os.path.join(self.root, "256_ObjectCategories", c))) n = len(os.listdir(os.path.join(self.root, "256_ObjectCategories", c)))
self.index.extend(range(1, n + 1)) self.index.extend(range(1, n + 1))
self.y.extend(n * [i]) self.y.extend(n * [i])
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -187,14 +200,14 @@ class Caltech256(VisionDataset): ...@@ -187,14 +200,14 @@ class Caltech256(VisionDataset):
return img, target return img, target
def _check_integrity(self): def _check_integrity(self) -> bool:
# can be more robust and check hash of files # can be more robust and check hash of files
return os.path.exists(os.path.join(self.root, "256_ObjectCategories")) return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
def __len__(self): def __len__(self) -> int:
return len(self.index) return len(self.index)
def download(self): def download(self) -> None:
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment