Unverified Commit 6db1569c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

cifar (#2527)

parent 47f80acc
......@@ -3,6 +3,7 @@ import os
import os.path
import numpy as np
import pickle
from typing import Any, Callable, Optional, Tuple
from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive
......@@ -46,8 +47,14 @@ class CIFAR10(VisionDataset):
'md5': '5ff9c542aee3614f3951f8cda6e48888',
}
def __init__(self, root, train=True, transform=None, target_transform=None,
download=False):
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(CIFAR10, self).__init__(root, transform=transform,
target_transform=target_transform)
......@@ -66,7 +73,7 @@ class CIFAR10(VisionDataset):
else:
downloaded_list = self.test_list
self.data = []
self.data: Any = []
self.targets = []
# now load the picked numpy arrays
......@@ -85,7 +92,7 @@ class CIFAR10(VisionDataset):
self._load_meta()
def _load_meta(self):
def _load_meta(self) -> None:
path = os.path.join(self.root, self.base_folder, self.meta['filename'])
if not check_integrity(path, self.meta['md5']):
raise RuntimeError('Dataset metadata file not found or corrupted.' +
......@@ -95,7 +102,7 @@ class CIFAR10(VisionDataset):
self.classes = data[self.meta['key']]
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -117,10 +124,10 @@ class CIFAR10(VisionDataset):
return img, target
def __len__(self):
def __len__(self) -> int:
return len(self.data)
def _check_integrity(self):
def _check_integrity(self) -> bool:
root = self.root
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
......@@ -129,13 +136,13 @@ class CIFAR10(VisionDataset):
return False
return True
def download(self):
def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
def extra_repr(self):
def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment