"vscode:/vscode.git/clone" did not exist on "aaaecbc9030acbab35ec55db4becde9ca8b765b4"
Unverified Commit 31245cb8 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

caltech (#2521)

parent e1c50d9c
from PIL import Image
import os
import os.path
from typing import Any, Callable, List, Optional, Union, Tuple
from .vision import VisionDataset
from .utils import download_and_extract_archive, verify_str_arg
......@@ -29,8 +30,14 @@ class Caltech101(VisionDataset):
downloaded again.
"""
def __init__(self, root, target_type="category", transform=None,
target_transform=None, download=False):
def __init__(
self,
root: str,
target_type: Union[List[str], str] = "category",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(Caltech101, self).__init__(os.path.join(root, 'caltech101'),
transform=transform,
target_transform=target_transform)
......@@ -59,14 +66,14 @@ class Caltech101(VisionDataset):
"airplanes": "Airplanes_Side_2"}
self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
self.index = []
self.index: List[int] = []
self.y = []
for (i, c) in enumerate(self.categories):
n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
self.index.extend(range(1, n + 1))
self.y.extend(n * [i])
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -81,7 +88,7 @@ class Caltech101(VisionDataset):
self.categories[self.y[index]],
"image_{:04d}.jpg".format(self.index[index])))
target = []
target: Any = []
for t in self.target_type:
if t == "category":
target.append(self.y[index])
......@@ -101,14 +108,14 @@ class Caltech101(VisionDataset):
return img, target
def _check_integrity(self):
def _check_integrity(self) -> bool:
# can be more robust and check hash of files
return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))
def __len__(self):
def __len__(self) -> int:
return len(self.index)
def download(self):
def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
......@@ -124,7 +131,7 @@ class Caltech101(VisionDataset):
filename="101_Annotations.tar",
md5="6f83eeb1f24d99cab4eb377263132c91")
def extra_repr(self):
def extra_repr(self) -> str:
return "Target type: {target_type}".format(**self.__dict__)
......@@ -143,7 +150,13 @@ class Caltech256(VisionDataset):
downloaded again.
"""
def __init__(self, root, transform=None, target_transform=None, download=False):
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(Caltech256, self).__init__(os.path.join(root, 'caltech256'),
transform=transform,
target_transform=target_transform)
......@@ -157,14 +170,14 @@ class Caltech256(VisionDataset):
' You can use download=True to download it')
self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
self.index = []
self.index: List[int] = []
self.y = []
for (i, c) in enumerate(self.categories):
n = len(os.listdir(os.path.join(self.root, "256_ObjectCategories", c)))
self.index.extend(range(1, n + 1))
self.y.extend(n * [i])
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -187,14 +200,14 @@ class Caltech256(VisionDataset):
return img, target
def _check_integrity(self):
def _check_integrity(self) -> bool:
# can be more robust and check hash of files
return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
def __len__(self):
def __len__(self) -> int:
return len(self.index)
def download(self):
def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment