Unverified Commit 6db1569c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

cifar (#2527)

parent 47f80acc
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
import os.path import os.path
import numpy as np import numpy as np
import pickle import pickle
from typing import Any, Callable, Optional, Tuple
from .vision import VisionDataset from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive from .utils import check_integrity, download_and_extract_archive
...@@ -46,8 +47,14 @@ class CIFAR10(VisionDataset): ...@@ -46,8 +47,14 @@ class CIFAR10(VisionDataset):
'md5': '5ff9c542aee3614f3951f8cda6e48888', 'md5': '5ff9c542aee3614f3951f8cda6e48888',
} }
def __init__(self, root, train=True, transform=None, target_transform=None, def __init__(
download=False): self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(CIFAR10, self).__init__(root, transform=transform, super(CIFAR10, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
...@@ -66,7 +73,7 @@ class CIFAR10(VisionDataset): ...@@ -66,7 +73,7 @@ class CIFAR10(VisionDataset):
else: else:
downloaded_list = self.test_list downloaded_list = self.test_list
self.data = [] self.data: Any = []
self.targets = [] self.targets = []
# now load the picked numpy arrays # now load the picked numpy arrays
...@@ -85,7 +92,7 @@ class CIFAR10(VisionDataset): ...@@ -85,7 +92,7 @@ class CIFAR10(VisionDataset):
self._load_meta() self._load_meta()
def _load_meta(self): def _load_meta(self) -> None:
path = os.path.join(self.root, self.base_folder, self.meta['filename']) path = os.path.join(self.root, self.base_folder, self.meta['filename'])
if not check_integrity(path, self.meta['md5']): if not check_integrity(path, self.meta['md5']):
raise RuntimeError('Dataset metadata file not found or corrupted.' + raise RuntimeError('Dataset metadata file not found or corrupted.' +
...@@ -95,7 +102,7 @@ class CIFAR10(VisionDataset): ...@@ -95,7 +102,7 @@ class CIFAR10(VisionDataset):
self.classes = data[self.meta['key']] self.classes = data[self.meta['key']]
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -117,10 +124,10 @@ class CIFAR10(VisionDataset): ...@@ -117,10 +124,10 @@ class CIFAR10(VisionDataset):
return img, target return img, target
def __len__(self): def __len__(self) -> int:
return len(self.data) return len(self.data)
def _check_integrity(self): def _check_integrity(self) -> bool:
root = self.root root = self.root
for fentry in (self.train_list + self.test_list): for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1] filename, md5 = fentry[0], fentry[1]
...@@ -129,13 +136,13 @@ class CIFAR10(VisionDataset): ...@@ -129,13 +136,13 @@ class CIFAR10(VisionDataset):
return False return False
return True return True
def download(self): def download(self) -> None:
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
def extra_repr(self): def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test") return "Split: {}".format("Train" if self.train is True else "Test")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment