Unverified Commit 5f0edb97 authored by Philip Meier, committed by GitHub

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* got a double import in merge
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent e45489b1
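
The changes below are mechanical: ufmt chains usort (import sorting) and black (code formatting). usort splits each module's imports into blank-line-separated groups (standard library, then third party, then local), while black normalizes quotes to double quotes, adds a trailing comma to exploded literals and calls, and joins or splits statements against the configured line length (120 characters, judging by the reflowed lines). A hand-written sketch of the resulting import layout, using folder.py as the example:

    # Import layout after `ufmt` (usort + black), as applied throughout this
    # diff: three groups, separated by blank lines.
    import os  # 1. standard library
    import os.path
    from typing import Any, Callable, cast, Dict, List, Optional, Tuple

    from PIL import Image  # 2. third party

    from .vision import VisionDataset  # 3. local / first party

With the pre-commit hook added here, these rewrites run automatically on staged files (e.g. via pre-commit run --all-files).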
torchvision/datasets/folder.py

-from .vision import VisionDataset
-from PIL import Image
 import os
 import os.path
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple
+
+from PIL import Image
+
+from .vision import VisionDataset


 def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
     """Checks if a file is an allowed extension.

@@ -132,16 +132,15 @@ class DatasetFolder(VisionDataset):
     """

     def __init__(
         self,
         root: str,
         loader: Callable[[str], Any],
         extensions: Optional[Tuple[str, ...]] = None,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         is_valid_file: Optional[Callable[[str], bool]] = None,
     ) -> None:
-        super(DatasetFolder, self).__init__(root, transform=transform,
-                                            target_transform=target_transform)
+        super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform)
         classes, class_to_idx = self.find_classes(self.root)
         samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

@@ -186,9 +185,7 @@ class DatasetFolder(VisionDataset):
             # prevent potential bug since make_dataset() would use the class_to_idx logic of the
             # find_classes() function, instead of using that of the find_classes() method, which
             # is potentially overridden and thus could have a different logic.
-            raise ValueError(
-                "The class_to_idx parameter cannot be None."
-            )
+            raise ValueError("The class_to_idx parameter cannot be None.")
         return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

     def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:

@@ -241,19 +238,20 @@ class DatasetFolder(VisionDataset):
         return len(self.samples)


-IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
+IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


 def pil_loader(path: str) -> Image.Image:
     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
-    with open(path, 'rb') as f:
+    with open(path, "rb") as f:
         img = Image.open(f)
-        return img.convert('RGB')
+        return img.convert("RGB")


 # TODO: specify the return type
 def accimage_loader(path: str) -> Any:
     import accimage

     try:
         return accimage.Image(path)
     except IOError:

@@ -263,7 +261,8 @@ def accimage_loader(path: str) -> Any:

 def default_loader(path: str) -> Any:
     from torchvision import get_image_backend
-    if get_image_backend() == 'accimage':
+
+    if get_image_backend() == "accimage":
         return accimage_loader(path)
     else:
         return pil_loader(path)

@@ -300,15 +299,19 @@ class ImageFolder(DatasetFolder):
     """

     def __init__(
         self,
         root: str,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         loader: Callable[[str], Any] = default_loader,
         is_valid_file: Optional[Callable[[str], bool]] = None,
     ):
-        super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
-                                          transform=transform,
-                                          target_transform=target_transform,
-                                          is_valid_file=is_valid_file)
+        super(ImageFolder, self).__init__(
+            root,
+            loader,
+            IMG_EXTENSIONS if is_valid_file is None else None,
+            transform=transform,
+            target_transform=target_transform,
+            is_valid_file=is_valid_file,
+        )
         self.imgs = self.samples
torchvision/datasets/hmdb51.py

 import glob
 import os
 from typing import Optional, Callable, Tuple, Dict, Any, List
+
 from torch import Tensor

 from .folder import find_classes, make_dataset

@@ -49,7 +50,7 @@ class HMDB51(VisionDataset):
     data_url = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"
     splits = {
         "url": "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
-        "md5": "15e67781e70dcfbdce2d7dbb9b3344b5"
+        "md5": "15e67781e70dcfbdce2d7dbb9b3344b5",
     }
     TRAIN_TAG = 1
     TEST_TAG = 2

@@ -75,7 +76,7 @@ class HMDB51(VisionDataset):
         if fold not in (1, 2, 3):
             raise ValueError("fold should be between 1 and 3, got {}".format(fold))

-        extensions = ('avi',)
+        extensions = ("avi",)
         self.classes, class_to_idx = find_classes(self.root)
         self.samples = make_dataset(
             self.root,
...
torchvision/datasets/imagenet.py

-import warnings
-from contextlib import contextmanager
 import os
 import shutil
 import tempfile
+import warnings
+from contextlib import contextmanager
 from typing import Any, Dict, List, Iterator, Optional, Tuple

 import torch

 from .folder import ImageFolder
 from .utils import check_integrity, extract_archive, verify_str_arg

 ARCHIVE_META = {
-    'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'),
-    'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'),
-    'devkit': ('ILSVRC2012_devkit_t12.tar.gz', 'fa75699e90414af021442c21a62c3abf')
+    "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
+    "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
+    "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
 }

 META_FILE = "meta.bin"

@@ -38,15 +40,16 @@ class ImageNet(ImageFolder):
         targets (list): The class_index value for each image in the dataset
     """

-    def __init__(self, root: str, split: str = 'train', download: Optional[str] = None, **kwargs: Any) -> None:
+    def __init__(self, root: str, split: str = "train", download: Optional[str] = None, **kwargs: Any) -> None:
         if download is True:
-            msg = ("The dataset is no longer publicly accessible. You need to "
-                   "download the archives externally and place them in the root "
-                   "directory.")
+            msg = (
+                "The dataset is no longer publicly accessible. You need to "
+                "download the archives externally and place them in the root "
+                "directory."
+            )
             raise RuntimeError(msg)
         elif download is False:
-            msg = ("The use of the download flag is deprecated, since the dataset "
-                   "is no longer publicly accessible.")
+            msg = "The use of the download flag is deprecated, since the dataset " "is no longer publicly accessible."
             warnings.warn(msg, RuntimeWarning)

         root = self.root = os.path.expanduser(root)

@@ -61,18 +64,16 @@ class ImageNet(ImageFolder):
         self.wnids = self.classes
         self.wnid_to_idx = self.class_to_idx
         self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
-        self.class_to_idx = {cls: idx
-                             for idx, clss in enumerate(self.classes)
-                             for cls in clss}
+        self.class_to_idx = {cls: idx for idx, clss in enumerate(self.classes) for cls in clss}

     def parse_archives(self) -> None:
         if not check_integrity(os.path.join(self.root, META_FILE)):
             parse_devkit_archive(self.root)

         if not os.path.isdir(self.split_folder):
-            if self.split == 'train':
+            if self.split == "train":
                 parse_train_archive(self.root)
-            elif self.split == 'val':
+            elif self.split == "val":
                 parse_val_archive(self.root)

     @property

@@ -91,15 +92,19 @@ def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str
     if check_integrity(file):
         return torch.load(file)
     else:
-        msg = ("The meta file {} is not present in the root directory or is corrupted. "
-               "This file is automatically created by the ImageNet dataset.")
+        msg = (
+            "The meta file {} is not present in the root directory or is corrupted. "
+            "This file is automatically created by the ImageNet dataset."
+        )
         raise RuntimeError(msg.format(file, root))


 def _verify_archive(root: str, file: str, md5: str) -> None:
     if not check_integrity(os.path.join(root, file), md5):
-        msg = ("The archive {} is not present in the root directory or is corrupted. "
-               "You need to download it externally and place it in {}.")
+        msg = (
+            "The archive {} is not present in the root directory or is corrupted. "
+            "You need to download it externally and place it in {}."
+        )
         raise RuntimeError(msg.format(file, root))

@@ -116,20 +121,18 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
     def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, str]]:
         metafile = os.path.join(devkit_root, "data", "meta.mat")
-        meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
+        meta = sio.loadmat(metafile, squeeze_me=True)["synsets"]
         nums_children = list(zip(*meta))[4]
-        meta = [meta[idx] for idx, num_children in enumerate(nums_children)
-                if num_children == 0]
+        meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
         idcs, wnids, classes = list(zip(*meta))[:3]
-        classes = [tuple(clss.split(', ')) for clss in classes]
+        classes = [tuple(clss.split(", ")) for clss in classes]
         idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
         wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
         return idx_to_wnid, wnid_to_classes

     def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
-        file = os.path.join(devkit_root, "data",
-                            "ILSVRC2012_validation_ground_truth.txt")
-        with open(file, 'r') as txtfh:
+        file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
+        with open(file, "r") as txtfh:
             val_idcs = txtfh.readlines()
         return [int(val_idx) for val_idx in val_idcs]
...
torchvision/datasets/inaturalist.py

-from PIL import Image
 import os
 import os.path
 from typing import Any, Callable, Dict, List, Optional, Union, Tuple
-from .vision import VisionDataset
+
+from PIL import Image
+
 from .utils import download_and_extract_archive, verify_str_arg
+from .vision import VisionDataset

 CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]

 DATASET_URLS = {
-    '2017': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz',
-    '2018': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz',
-    '2019': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz',
-    '2021_train': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz',
-    '2021_train_mini': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz',
-    '2021_valid': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz',
+    "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz",
+    "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz",
+    "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz",
+    "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
+    "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz",
+    "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
 }
 DATASET_MD5 = {
-    '2017': '7c784ea5e424efaec655bd392f87301f',
-    '2018': 'b1c6952ce38f31868cc50ea72d066cc3',
-    '2019': 'c60a6e2962c9b8ccbd458d12c8582644',
-    '2021_train': '38a7bb733f7a09214d44293460ec0021',
-    '2021_train_mini': 'db6ed8330e634445efc8fec83ae81442',
-    '2021_valid': 'f6f6e0e242e3d4c9569ba56400938afc',
+    "2017": "7c784ea5e424efaec655bd392f87301f",
+    "2018": "b1c6952ce38f31868cc50ea72d066cc3",
+    "2019": "c60a6e2962c9b8ccbd458d12c8582644",
+    "2021_train": "38a7bb733f7a09214d44293460ec0021",
+    "2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
+    "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
 }

@@ -63,27 +64,26 @@ class INaturalist(VisionDataset):
     """

     def __init__(
         self,
         root: str,
         version: str = "2021_train",
         target_type: Union[List[str], str] = "full",
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         download: bool = False,
     ) -> None:
         self.version = verify_str_arg(version, "version", DATASET_URLS.keys())

-        super(INaturalist, self).__init__(os.path.join(root, version),
-                                          transform=transform,
-                                          target_transform=target_transform)
+        super(INaturalist, self).__init__(
+            os.path.join(root, version), transform=transform, target_transform=target_transform
+        )

         os.makedirs(root, exist_ok=True)
         if download:
             self.download()

         if not self._check_integrity():
-            raise RuntimeError('Dataset not found or corrupted.' +
-                               ' You can use download=True to download it')
+            raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")

         self.all_categories: List[str] = []

@@ -96,12 +96,10 @@ class INaturalist(VisionDataset):
         if not isinstance(target_type, list):
             target_type = [target_type]
         if self.version[:4] == "2021":
-            self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021))
-                                for t in target_type]
+            self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
             self._init_2021()
         else:
-            self.target_type = [verify_str_arg(t, "target_type", ("full", "super"))
-                                for t in target_type]
+            self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
             self._init_pre2021()

         # index of all files: (full category id, filename)

@@ -118,16 +116,14 @@ class INaturalist(VisionDataset):
         self.all_categories = sorted(os.listdir(self.root))

         # map: category type -> name of category -> index
-        self.categories_index = {
-            k: {} for k in CATEGORIES_2021
-        }
+        self.categories_index = {k: {} for k in CATEGORIES_2021}

         for dir_index, dir_name in enumerate(self.all_categories):
-            pieces = dir_name.split('_')
+            pieces = dir_name.split("_")
             if len(pieces) != 8:
-                raise RuntimeError(f'Unexpected category name {dir_name}, wrong number of pieces')
-            if pieces[0] != f'{dir_index:05d}':
-                raise RuntimeError(f'Unexpected category id {pieces[0]}, expecting {dir_index:05d}')
+                raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
+            if pieces[0] != f"{dir_index:05d}":
+                raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
             cat_map = {}
             for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
                 if name in self.categories_index[cat]:

@@ -142,7 +138,7 @@ class INaturalist(VisionDataset):
         """Initialize based on 2017-2019 layout"""

         # map: category type -> name of category -> index
-        self.categories_index = {'super': {}}
+        self.categories_index = {"super": {}}

         cat_index = 0
         super_categories = sorted(os.listdir(self.root))

@@ -165,7 +161,7 @@ class INaturalist(VisionDataset):
                     self.all_categories.extend([""] * (subcat_i - old_len + 1))
                 if self.categories_map[subcat_i]:
                     raise RuntimeError(f"Duplicate category {subcat}")
-                self.categories_map[subcat_i] = {'super': sindex}
+                self.categories_map[subcat_i] = {"super": sindex}
                 self.all_categories[subcat_i] = os.path.join(scat, subcat)

         # validate the dictionary

@@ -183,9 +179,7 @@ class INaturalist(VisionDataset):
         """
         cat_id, fname = self.index[index]
-        img = Image.open(os.path.join(self.root,
-                                      self.all_categories[cat_id],
-                                      fname))
+        img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))

         target: Any = []
         for t in self.target_type:

@@ -239,10 +233,8 @@ class INaturalist(VisionDataset):
         base_root = os.path.dirname(self.root)

         download_and_extract_archive(
-            DATASET_URLS[self.version],
-            base_root,
-            filename=f"{self.version}.tgz",
-            md5=DATASET_MD5[self.version])
+            DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
+        )

         orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz"))
         if not os.path.exists(orig_dir_name):
...
torchvision/datasets/kinetics.py

-import time
+import csv
 import os
+import time
 import warnings
-from os import path
-import csv
-from typing import Any, Callable, Dict, Optional, Tuple
 from functools import partial
 from multiprocessing import Pool
+from os import path
+from typing import Any, Callable, Dict, Optional, Tuple

 from torch import Tensor

-from .utils import download_and_extract_archive, download_url, verify_str_arg, check_integrity
 from .folder import find_classes, make_dataset
+from .utils import download_and_extract_archive, download_url, verify_str_arg, check_integrity
 from .video_utils import VideoClips
 from .vision import VisionDataset

@@ -214,18 +213,13 @@ class Kinetics(VisionDataset):
                     start=int(row["time_start"]),
                     end=int(row["time_end"]),
                 )
-                label = (
-                    row["label"]
-                    .replace(" ", "_")
-                    .replace("'", "")
-                    .replace("(", "")
-                    .replace(")", "")
-                )
+                label = row["label"].replace(" ", "_").replace("'", "").replace("(", "").replace(")", "")

                 os.makedirs(path.join(self.split_folder, label), exist_ok=True)
                 downloaded_file = path.join(self.split_folder, f)
                 if path.isfile(downloaded_file):
                     os.replace(
-                        downloaded_file, path.join(self.split_folder, label, f),
+                        downloaded_file,
+                        path.join(self.split_folder, label, f),
                     )

     @property

@@ -303,11 +297,12 @@ class Kinetics400(Kinetics):
         split: Any = None,
         download: Any = None,
         num_download_workers: Any = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> None:
         warnings.warn(
             "Kinetics400 is deprecated and will be removed in a future release."
-            "It was replaced by Kinetics(..., num_classes=\"400\").")
+            'It was replaced by Kinetics(..., num_classes="400").'
+        )
         if any(value is not None for value in (num_classes, split, download, num_download_workers)):
             raise RuntimeError(
                 "Usage of 'num_classes', 'split', 'download', or 'num_download_workers' is not supported in "
...
torchvision/datasets/kitti.py

@@ -73,9 +73,7 @@ class Kitti(VisionDataset):
         if download:
             self.download()
         if not self._check_exists():
-            raise RuntimeError(
-                "Dataset not found. You may use download=True to download it."
-            )
+            raise RuntimeError("Dataset not found. You may use download=True to download it.")

         image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name)
         if self.train:

@@ -83,9 +81,7 @@ class Kitti(VisionDataset):
         for img_file in os.listdir(image_dir):
             self.images.append(os.path.join(image_dir, img_file))
             if self.train:
-                self.targets.append(
-                    os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt")
-                )
+                self.targets.append(os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt"))

     def __getitem__(self, index: int) -> Tuple[Any, Any]:
         """Get item at a given index.

@@ -117,16 +113,18 @@ class Kitti(VisionDataset):
         with open(self.targets[index]) as inp:
             content = csv.reader(inp, delimiter=" ")
             for line in content:
-                target.append({
-                    "type": line[0],
-                    "truncated": float(line[1]),
-                    "occluded": int(line[2]),
-                    "alpha": float(line[3]),
-                    "bbox": [float(x) for x in line[4:8]],
-                    "dimensions": [float(x) for x in line[8:11]],
-                    "location": [float(x) for x in line[11:14]],
-                    "rotation_y": float(line[14]),
-                })
+                target.append(
+                    {
+                        "type": line[0],
+                        "truncated": float(line[1]),
+                        "occluded": int(line[2]),
+                        "alpha": float(line[3]),
+                        "bbox": [float(x) for x in line[4:8]],
+                        "dimensions": [float(x) for x in line[8:11]],
+                        "location": [float(x) for x in line[11:14]],
+                        "rotation_y": float(line[14]),
+                    }
+                )
         return target

     def __len__(self) -> int:

@@ -141,10 +139,7 @@ class Kitti(VisionDataset):
         folders = [self.image_dir_name]
         if self.train:
             folders.append(self.labels_dir_name)
-        return all(
-            os.path.isdir(os.path.join(self._raw_folder, self._location, fname))
-            for fname in folders
-        )
+        return all(os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) for fname in folders)

     def download(self) -> None:
         """Download the KITTI data if it doesn't exist already."""
...
torchvision/datasets/lfw.py

 import os
 from typing import Any, Callable, List, Optional, Tuple

 from PIL import Image

-from .vision import VisionDataset
 from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
+from .vision import VisionDataset


 class _LFW(VisionDataset):

-    base_folder = 'lfw-py'
+    base_folder = "lfw-py"
     download_url_prefix = "http://vis-www.cs.umass.edu/lfw/"

     file_dict = {
-        'original': ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"),
-        'funneled': ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"),
-        'deepfunneled': ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201")
+        "original": ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"),
+        "funneled": ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"),
+        "deepfunneled": ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201"),
     }
     checksums = {
-        'pairs.txt': '9f1ba174e4e1c508ff7cdf10ac338a7d',
-        'pairsDevTest.txt': '5132f7440eb68cf58910c8a45a2ac10b',
-        'pairsDevTrain.txt': '4f27cbf15b2da4a85c1907eb4181ad21',
-        'people.txt': '450f0863dd89e85e73936a6d71a3474b',
-        'peopleDevTest.txt': 'e4bf5be0a43b5dcd9dc5ccfcb8fb19c5',
-        'peopleDevTrain.txt': '54eaac34beb6d042ed3a7d883e247a21',
-        'lfw-names.txt': 'a6d0a479bd074669f656265a6e693f6d'
+        "pairs.txt": "9f1ba174e4e1c508ff7cdf10ac338a7d",
+        "pairsDevTest.txt": "5132f7440eb68cf58910c8a45a2ac10b",
+        "pairsDevTrain.txt": "4f27cbf15b2da4a85c1907eb4181ad21",
+        "people.txt": "450f0863dd89e85e73936a6d71a3474b",
+        "peopleDevTest.txt": "e4bf5be0a43b5dcd9dc5ccfcb8fb19c5",
+        "peopleDevTrain.txt": "54eaac34beb6d042ed3a7d883e247a21",
+        "lfw-names.txt": "a6d0a479bd074669f656265a6e693f6d",
     }
-    annot_file = {'10fold': '', 'train': 'DevTrain', 'test': 'DevTest'}
+    annot_file = {"10fold": "", "train": "DevTrain", "test": "DevTest"}
     names = "lfw-names.txt"

     def __init__(

@@ -37,14 +39,15 @@ class _LFW(VisionDataset):
         target_transform: Optional[Callable] = None,
         download: bool = False,
     ):
-        super(_LFW, self).__init__(os.path.join(root, self.base_folder),
-                                   transform=transform, target_transform=target_transform)
+        super(_LFW, self).__init__(
+            os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform
+        )

-        self.image_set = verify_str_arg(image_set.lower(), 'image_set', self.file_dict.keys())
+        self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys())
         images_dir, self.filename, self.md5 = self.file_dict[self.image_set]

-        self.view = verify_str_arg(view.lower(), 'view', ['people', 'pairs'])
-        self.split = verify_str_arg(split.lower(), 'split', ['10fold', 'train', 'test'])
+        self.view = verify_str_arg(view.lower(), "view", ["people", "pairs"])
+        self.split = verify_str_arg(split.lower(), "split", ["10fold", "train", "test"])
         self.labels_file = f"{self.view}{self.annot_file[self.split]}.txt"
         self.data: List[Any] = []

@@ -52,15 +55,14 @@ class _LFW(VisionDataset):
             self.download()

         if not self._check_integrity():
-            raise RuntimeError('Dataset not found or corrupted.' +
-                               ' You can use download=True to download it')
+            raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")

         self.images_dir = os.path.join(self.root, images_dir)

     def _loader(self, path: str) -> Image.Image:
-        with open(path, 'rb') as f:
+        with open(path, "rb") as f:
             img = Image.open(f)
-            return img.convert('RGB')
+            return img.convert("RGB")

     def _check_integrity(self):
         st1 = check_integrity(os.path.join(self.root, self.filename), self.md5)

@@ -73,7 +75,7 @@ class _LFW(VisionDataset):

     def download(self):
         if self._check_integrity():
-            print('Files already downloaded and verified')
+            print("Files already downloaded and verified")
             return
         url = f"{self.download_url_prefix}{self.filename}"
         download_and_extract_archive(url, self.root, filename=self.filename, md5=self.md5)

@@ -120,21 +122,20 @@ class LFWPeople(_LFW):
         target_transform: Optional[Callable] = None,
         download: bool = False,
     ):
-        super(LFWPeople, self).__init__(root, split, image_set, "people",
-                                        transform, target_transform, download)
+        super(LFWPeople, self).__init__(root, split, image_set, "people", transform, target_transform, download)

         self.class_to_idx = self._get_classes()
         self.data, self.targets = self._get_people()

     def _get_people(self):
         data, targets = [], []
-        with open(os.path.join(self.root, self.labels_file), 'r') as f:
+        with open(os.path.join(self.root, self.labels_file), "r") as f:
             lines = f.readlines()
             n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0)

             for fold in range(n_folds):
                 n_lines = int(lines[s])
-                people = [line.strip().split("\t") for line in lines[s + 1: s + n_lines + 1]]
+                people = [line.strip().split("\t") for line in lines[s + 1 : s + n_lines + 1]]
                 s += n_lines + 1
                 for i, (identity, num_imgs) in enumerate(people):
                     for num in range(1, int(num_imgs) + 1):

@@ -145,7 +146,7 @@ class LFWPeople(_LFW):
         return data, targets

     def _get_classes(self):
-        with open(os.path.join(self.root, self.names), 'r') as f:
+        with open(os.path.join(self.root, self.names), "r") as f:
             lines = f.readlines()
             names = [line.strip().split()[0] for line in lines]
         class_to_idx = {name: i for i, name in enumerate(names)}

@@ -203,14 +204,13 @@ class LFWPairs(_LFW):
         target_transform: Optional[Callable] = None,
         download: bool = False,
     ):
-        super(LFWPairs, self).__init__(root, split, image_set, "pairs",
-                                       transform, target_transform, download)
+        super(LFWPairs, self).__init__(root, split, image_set, "pairs", transform, target_transform, download)

         self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)

     def _get_pairs(self, images_dir):
         pair_names, data, targets = [], [], []
-        with open(os.path.join(self.root, self.labels_file), 'r') as f:
+        with open(os.path.join(self.root, self.labels_file), "r") as f:
             lines = f.readlines()
             if self.split == "10fold":
                 n_folds, n_pairs = lines[0].split("\t")

@@ -220,9 +220,9 @@ class LFWPairs(_LFW):
                 s = 1

             for fold in range(n_folds):
-                matched_pairs = [line.strip().split("\t") for line in lines[s: s + n_pairs]]
-                unmatched_pairs = [line.strip().split("\t") for line in lines[s + n_pairs: s + (2 * n_pairs)]]
-                s += (2 * n_pairs)
+                matched_pairs = [line.strip().split("\t") for line in lines[s : s + n_pairs]]
+                unmatched_pairs = [line.strip().split("\t") for line in lines[s + n_pairs : s + (2 * n_pairs)]]
+                s += 2 * n_pairs
                 for pair in matched_pairs:
                     img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[0], pair[2]), 1
                     pair_names.append((pair[0], pair[0]))
...
torchvision/datasets/lsun.py

-from .vision import VisionDataset
-from PIL import Image
+import io
 import os
 import os.path
-import io
+import pickle
 import string
 from collections.abc import Iterable
-import pickle
 from typing import Any, Callable, cast, List, Optional, Tuple, Union

+from PIL import Image
+
 from .utils import verify_str_arg, iterable_to_str
+from .vision import VisionDataset


 class LSUNClass(VisionDataset):
     def __init__(
-        self, root: str, transform: Optional[Callable] = None,
-        target_transform: Optional[Callable] = None
+        self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
     ) -> None:
         import lmdb

-        super(LSUNClass, self).__init__(root, transform=transform,
-                                        target_transform=target_transform)
-        self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False,
-                             readahead=False, meminit=False)
+        super(LSUNClass, self).__init__(root, transform=transform, target_transform=target_transform)
+        self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
         with self.env.begin(write=False) as txn:
-            self.length = txn.stat()['entries']
-        cache_file = '_cache_' + ''.join(c for c in root if c in string.ascii_letters)
+            self.length = txn.stat()["entries"]
+        cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters)
         if os.path.isfile(cache_file):
             self.keys = pickle.load(open(cache_file, "rb"))
         else:

@@ -40,7 +40,7 @@ class LSUNClass(VisionDataset):
         buf = io.BytesIO()
         buf.write(imgbuf)
         buf.seek(0)
-        img = Image.open(buf).convert('RGB')
+        img = Image.open(buf).convert("RGB")

         if self.transform is not None:
             img = self.transform(img)

@@ -71,22 +71,19 @@ class LSUN(VisionDataset):
     """

     def __init__(
         self,
         root: str,
         classes: Union[str, List[str]] = "train",
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
     ) -> None:
-        super(LSUN, self).__init__(root, transform=transform,
-                                   target_transform=target_transform)
+        super(LSUN, self).__init__(root, transform=transform, target_transform=target_transform)
         self.classes = self._verify_classes(classes)

         # for each class, create an LSUNClassDataset
         self.dbs = []
         for c in self.classes:
-            self.dbs.append(LSUNClass(
-                root=os.path.join(root, f"{c}_lmdb"),
-                transform=transform))
+            self.dbs.append(LSUNClass(root=os.path.join(root, f"{c}_lmdb"), transform=transform))

         self.indices = []
         count = 0

@@ -97,35 +94,41 @@ class LSUN(VisionDataset):
         self.length = count

     def _verify_classes(self, classes: Union[str, List[str]]) -> List[str]:
-        categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
-                      'conference_room', 'dining_room', 'kitchen',
-                      'living_room', 'restaurant', 'tower']
-        dset_opts = ['train', 'val', 'test']
+        categories = [
+            "bedroom",
+            "bridge",
+            "church_outdoor",
+            "classroom",
+            "conference_room",
+            "dining_room",
+            "kitchen",
+            "living_room",
+            "restaurant",
+            "tower",
+        ]
+        dset_opts = ["train", "val", "test"]

         try:
             classes = cast(str, classes)
             verify_str_arg(classes, "classes", dset_opts)
-            if classes == 'test':
+            if classes == "test":
                 classes = [classes]
             else:
-                classes = [c + '_' + classes for c in categories]
+                classes = [c + "_" + classes for c in categories]
         except ValueError:
             if not isinstance(classes, Iterable):
-                msg = ("Expected type str or Iterable for argument classes, "
-                       "but got type {}.")
+                msg = "Expected type str or Iterable for argument classes, " "but got type {}."
                 raise ValueError(msg.format(type(classes)))

             classes = list(classes)
-            msg_fmtstr_type = ("Expected type str for elements in argument classes, "
-                               "but got type {}.")
+            msg_fmtstr_type = "Expected type str for elements in argument classes, " "but got type {}."
             for c in classes:
                 verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
-                c_short = c.split('_')
-                category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]
+                c_short = c.split("_")
+                category, dset_opt = "_".join(c_short[:-1]), c_short[-1]

                 msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
-                msg = msg_fmtstr.format(category, "LSUN class",
-                                        iterable_to_str(categories))
+                msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories))
                 verify_str_arg(category, valid_values=categories, custom_msg=msg)

                 msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
...
torchvision/datasets/mnist.py

-from .vision import VisionDataset
-import warnings
-from PIL import Image
+import codecs
 import os
 import os.path
-import numpy as np
-import torch
-import codecs
+import shutil
 import string
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple
 from urllib.error import URLError

+import numpy as np
+import torch
+from PIL import Image
+
 from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity
-import shutil
+from .vision import VisionDataset


 class MNIST(VisionDataset):

@@ -31,21 +33,31 @@ class MNIST(VisionDataset):
     """

     mirrors = [
-        'http://yann.lecun.com/exdb/mnist/',
-        'https://ossci-datasets.s3.amazonaws.com/mnist/',
+        "http://yann.lecun.com/exdb/mnist/",
+        "https://ossci-datasets.s3.amazonaws.com/mnist/",
     ]

     resources = [
         ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
         ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
         ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
-        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
+        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
     ]

-    training_file = 'training.pt'
-    test_file = 'test.pt'
-    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
-               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
+    training_file = "training.pt"
+    test_file = "test.pt"
+    classes = [
+        "0 - zero",
+        "1 - one",
+        "2 - two",
+        "3 - three",
+        "4 - four",
+        "5 - five",
+        "6 - six",
+        "7 - seven",
+        "8 - eight",
+        "9 - nine",
+    ]

     @property
     def train_labels(self):

@@ -68,15 +80,14 @@ class MNIST(VisionDataset):
         return self.data

     def __init__(
         self,
         root: str,
         train: bool = True,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         download: bool = False,
     ) -> None:
-        super(MNIST, self).__init__(root, transform=transform,
-                                    target_transform=target_transform)
+        super(MNIST, self).__init__(root, transform=transform, target_transform=target_transform)
         self.train = train  # training set or test set

         if self._check_legacy_exist():

@@ -87,8 +98,7 @@ class MNIST(VisionDataset):
             self.download()

         if not self._check_exists():
-            raise RuntimeError('Dataset not found.' +
-                               ' You can use download=True to download it')
+            raise RuntimeError("Dataset not found." + " You can use download=True to download it")

         self.data, self.targets = self._load_data()

@@ -128,7 +138,7 @@ class MNIST(VisionDataset):
         # doing this so that it is consistent with all other datasets
         # to return a PIL Image
-        img = Image.fromarray(img.numpy(), mode='L')
+        img = Image.fromarray(img.numpy(), mode="L")

         if self.transform is not None:
             img = self.transform(img)

@@ -143,11 +153,11 @@ class MNIST(VisionDataset):
     @property
     def raw_folder(self) -> str:
-        return os.path.join(self.root, self.__class__.__name__, 'raw')
+        return os.path.join(self.root, self.__class__.__name__, "raw")

     @property
     def processed_folder(self) -> str:
-        return os.path.join(self.root, self.__class__.__name__, 'processed')
+        return os.path.join(self.root, self.__class__.__name__, "processed")

     @property
     def class_to_idx(self) -> Dict[str, int]:

@@ -173,15 +183,9 @@ class MNIST(VisionDataset):
                 url = "{}{}".format(mirror, filename)
                 try:
                     print("Downloading {}".format(url))
-                    download_and_extract_archive(
-                        url, download_root=self.raw_folder,
-                        filename=filename,
-                        md5=md5
-                    )
+                    download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
                 except URLError as error:
-                    print(
-                        "Failed to download (trying next):\n{}".format(error)
-                    )
+                    print("Failed to download (trying next):\n{}".format(error))
                     continue
                 finally:
                     print()

@@ -209,18 +213,16 @@ class FashionMNIST(MNIST):
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
     """

-    mirrors = [
-        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"
-    ]
+    mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]

     resources = [
         ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
         ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
         ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
-        ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310")
+        ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
     ]

-    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
-               'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
+    classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]


 class KMNIST(MNIST):

@@ -239,17 +241,16 @@ class KMNIST(MNIST):
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
     """

-    mirrors = [
-        "http://codh.rois.ac.jp/kmnist/dataset/kmnist/"
-    ]
+    mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]

     resources = [
         ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
         ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
         ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
-        ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134")
+        ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
     ]

-    classes = ['o', 'ki', 'su', 'tsu', 'na', 'ha', 'ma', 'ya', 're', 'wo']
+    classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]


 class EMNIST(MNIST):

@@ -271,19 +272,20 @@ class EMNIST(MNIST):
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
     """

-    url = 'https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip'
+    url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip"
     md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
-    splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
+    splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
     # Merged Classes assumes Same structure for both uppercase and lowercase version
-    _merged_classes = {'c', 'i', 'j', 'k', 'l', 'm', 'o', 'p', 's', 'u', 'v', 'w', 'x', 'y', 'z'}
+    _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
     _all_classes = set(string.digits + string.ascii_letters)
     classes_split_dict = {
-        'byclass': sorted(list(_all_classes)),
-        'bymerge': sorted(list(_all_classes - _merged_classes)),
-        'balanced': sorted(list(_all_classes - _merged_classes)),
-        'letters': ['N/A'] + list(string.ascii_lowercase),
-        'digits': list(string.digits),
-        'mnist': list(string.digits),
+        "byclass": sorted(list(_all_classes)),
+        "bymerge": sorted(list(_all_classes - _merged_classes)),
+        "balanced": sorted(list(_all_classes - _merged_classes)),
+        "letters": ["N/A"] + list(string.ascii_lowercase),
+        "digits": list(string.digits),
+        "mnist": list(string.digits),
     }

     def __init__(self, root: str, split: str, **kwargs: Any) -> None:

@@ -295,11 +297,11 @@ class EMNIST(MNIST):
     @staticmethod
     def _training_file(split) -> str:
-        return 'training_{}.pt'.format(split)
+        return "training_{}.pt".format(split)

     @staticmethod
     def _test_file(split) -> str:
-        return 'test_{}.pt'.format(split)
+        return "test_{}.pt".format(split)

     @property
     def _file_prefix(self) -> str:

@@ -328,9 +330,9 @@ class EMNIST(MNIST):
         os.makedirs(self.raw_folder, exist_ok=True)
         download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
-        gzip_folder = os.path.join(self.raw_folder, 'gzip')
+        gzip_folder = os.path.join(self.raw_folder, "gzip")
         for gzip_file in os.listdir(gzip_folder):
-            if gzip_file.endswith('.gz'):
+            if gzip_file.endswith(".gz"):
                 extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
         shutil.rmtree(gzip_folder)

@@ -365,39 +367,60 @@ class QMNIST(MNIST):
             training set ot the testing set. Default: True.
     """

-    subsets = {
-        'train': 'train',
-        'test': 'test',
-        'test10k': 'test',
-        'test50k': 'test',
-        'nist': 'nist'
-    }
+    subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
     resources: Dict[str, List[Tuple[str, str]]] = {  # type: ignore[assignment]
-        'train': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz',
-                   'ed72d4157d28c017586c42bc6afe6370'),
-                  ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz',
-                   '0058f8dd561b90ffdd0f734c6a30e5e4')],
-        'test': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz',
-                  '1394631089c404de565df7b7aeaf9412'),
-                 ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz',
-                  '5b5b05890a5e13444e108efe57b788aa')],
-        'nist': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz',
-                  '7f124b3b8ab81486c9d8c2749c17f834'),
-                 ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz',
-                  '5ed0e788978e45d4a8bd4b7caec3d79d')]
+        "train": [
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
+                "ed72d4157d28c017586c42bc6afe6370",
+            ),
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
+                "0058f8dd561b90ffdd0f734c6a30e5e4",
+            ),
+        ],
+        "test": [
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
+                "1394631089c404de565df7b7aeaf9412",
+            ),
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
+                "5b5b05890a5e13444e108efe57b788aa",
+            ),
+        ],
+        "nist": [
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
+                "7f124b3b8ab81486c9d8c2749c17f834",
+            ),
+            (
+                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
+                "5ed0e788978e45d4a8bd4b7caec3d79d",
+            ),
+        ],
     }
-    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
-               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
+    classes = [
+        "0 - zero",
+        "1 - one",
+        "2 - two",
+        "3 - three",
+        "4 - four",
+        "5 - five",
+        "6 - six",
+        "7 - seven",
+        "8 - eight",
+        "9 - nine",
+    ]

     def __init__(
-        self, root: str, what: Optional[str] = None, compat: bool = True,
-        train: bool = True, **kwargs: Any
+        self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
     ) -> None:
         if what is None:
-            what = 'train' if train else 'test'
+            what = "train" if train else "test"
         self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
         self.compat = compat
-        self.data_file = what + '.pt'
+        self.data_file = what + ".pt"
         self.training_file = self.data_file
         self.test_file = self.data_file
         super(QMNIST, self).__init__(root, train, **kwargs)

@@ -417,16 +440,16 @@ class QMNIST(MNIST):
     def _load_data(self):
         data = read_sn3_pascalvincent_tensor(self.images_file)
-        assert (data.dtype == torch.uint8)
-        assert (data.ndimension() == 3)
+        assert data.dtype == torch.uint8
+        assert data.ndimension() == 3

         targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
-        assert (targets.ndimension() == 2)
+        assert targets.ndimension() == 2

-        if self.what == 'test10k':
+        if self.what == "test10k":
             data = data[0:10000, :, :].clone()
             targets = targets[0:10000, :].clone()
-        elif self.what == 'test50k':
+        elif self.what == "test50k":
             data = data[10000:, :, :].clone()
             targets = targets[10000:, :].clone()

@@ -434,7 +457,7 @@ class QMNIST(MNIST):
     def download(self) -> None:
         """Download the QMNIST data if it doesn't exist already.
         Note that we only download what has been asked for (argument 'what').
         """
         if self._check_exists():
             return

@@ -443,7 +466,7 @@ class QMNIST(MNIST):
         split = self.resources[self.subsets[self.what]]

         for url, md5 in split:
-            filename = url.rpartition('/')[2]
+            filename = url.rpartition("/")[2]
             file_path = os.path.join(self.raw_folder, filename)
             if not os.path.isfile(file_path):
                 download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5)

@@ -451,7 +474,7 @@ class QMNIST(MNIST):
     def __getitem__(self, index: int) -> Tuple[Any, Any]:
         # redefined to handle the compat flag
         img, target = self.data[index], self.targets[index]
-        img = Image.fromarray(img.numpy(), mode='L')
+        img = Image.fromarray(img.numpy(), mode="L")
         if self.transform is not None:
             img = self.transform(img)
         if self.compat:

@@ -465,22 +488,22 @@ class QMNIST(MNIST):

 def get_int(b: bytes) -> int:
-    return int(codecs.encode(b, 'hex'), 16)
+    return int(codecs.encode(b, "hex"), 16)


 SN3_PASCALVINCENT_TYPEMAP = {
     8: (torch.uint8, np.uint8, np.uint8),
     9: (torch.int8, np.int8, np.int8),
-    11: (torch.int16, np.dtype('>i2'), 'i2'),
-    12: (torch.int32, np.dtype('>i4'), 'i4'),
-    13: (torch.float32, np.dtype('>f4'), 'f4'),
-    14: (torch.float64, np.dtype('>f8'), 'f8')
+    11: (torch.int16, np.dtype(">i2"), "i2"),
+    12: (torch.int32, np.dtype(">i4"), "i4"),
+    13: (torch.float32, np.dtype(">f4"), "f4"),
+    14: (torch.float64, np.dtype(">f8"), "f8"),
 }


 def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object. Argument may be a filename, compressed filename, or file object.
""" """
# read # read
with open(path, "rb") as f: with open(path, "rb") as f:
...@@ -492,7 +515,7 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso ...@@ -492,7 +515,7 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
assert 1 <= nd <= 3 assert 1 <= nd <= 3
assert 8 <= ty <= 14 assert 8 <= ty <= 14
m = SN3_PASCALVINCENT_TYPEMAP[ty] m = SN3_PASCALVINCENT_TYPEMAP[ty]
s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)] s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
assert parsed.shape[0] == np.prod(s) or not strict assert parsed.shape[0] == np.prod(s) or not strict
return torch.from_numpy(parsed.astype(m[2])).view(*s) return torch.from_numpy(parsed.astype(m[2])).view(*s)
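As a sanity check on the header arithmetic above, a small sketch that writes an IDX-style buffer by hand and parses it back. The magic-word decoding sits in an elided hunk, so the assumed layout (type code in the third byte, number of dimensions in the fourth, as in the classic IDX format) is an inference from the visible dimension and offset math:

import tempfile

import torch

from torchvision.datasets.mnist import read_sn3_pascalvincent_tensor

# Build a 2x3 uint8 tensor in the assumed SN3/IDX layout:
# 4-byte big-endian magic (type code 8 = uint8, ndim = 2),
# then one 4-byte big-endian size per dimension, then the raw payload.
magic = (8 * 256 + 2).to_bytes(4, "big")
dims = (2).to_bytes(4, "big") + (3).to_bytes(4, "big")
payload = bytes(range(6))  # values 0..5, row-major

with tempfile.NamedTemporaryFile(suffix=".idx", delete=False) as f:
    f.write(magic + dims + payload)  # delete=False so we can reopen by path

t = read_sn3_pascalvincent_tensor(f.name)
assert t.dtype == torch.uint8 and tuple(t.shape) == (2, 3)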
...@@ -500,13 +523,13 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso ...@@ -500,13 +523,13 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
def read_label_file(path: str) -> torch.Tensor: def read_label_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False) x = read_sn3_pascalvincent_tensor(path, strict=False)
assert(x.dtype == torch.uint8) assert x.dtype == torch.uint8
assert(x.ndimension() == 1) assert x.ndimension() == 1
return x.long() return x.long()
def read_image_file(path: str) -> torch.Tensor: def read_image_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False) x = read_sn3_pascalvincent_tensor(path, strict=False)
assert(x.dtype == torch.uint8) assert x.dtype == torch.uint8
assert(x.ndimension() == 3) assert x.ndimension() == 3
return x return x
from PIL import Image
from os.path import join from os.path import join
from typing import Any, Callable, List, Optional, Tuple from typing import Any, Callable, List, Optional, Tuple
from .vision import VisionDataset
from PIL import Image
from .utils import download_and_extract_archive, check_integrity, list_dir, list_files from .utils import download_and_extract_archive, check_integrity, list_dir, list_files
from .vision import VisionDataset
class Omniglot(VisionDataset): class Omniglot(VisionDataset):
...@@ -21,38 +23,40 @@ class Omniglot(VisionDataset): ...@@ -21,38 +23,40 @@ class Omniglot(VisionDataset):
puts it in root directory. If the zip files are already downloaded, they are not puts it in root directory. If the zip files are already downloaded, they are not
downloaded again. downloaded again.
""" """
folder = 'omniglot-py'
download_url_prefix = 'https://raw.githubusercontent.com/brendenlake/omniglot/master/python' folder = "omniglot-py"
download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python"
zips_md5 = { zips_md5 = {
'images_background': '68d2efa1b9178cc56df9314c21c6e718', "images_background": "68d2efa1b9178cc56df9314c21c6e718",
'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811' "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811",
} }
def __init__( def __init__(
self, self,
root: str, root: str,
background: bool = True, background: bool = True,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(Omniglot, self).__init__(join(root, self.folder), transform=transform, super(Omniglot, self).__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
target_transform=target_transform)
self.background = background self.background = background
if download: if download:
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
' You can use download=True to download it')
self.target_folder = join(self.root, self._get_target_folder()) self.target_folder = join(self.root, self._get_target_folder())
self._alphabets = list_dir(self.target_folder) self._alphabets = list_dir(self.target_folder)
self._characters: List[str] = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))] self._characters: List[str] = sum(
for a in self._alphabets], []) [[join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets], []
self._character_images = [[(image, idx) for image in list_files(join(self.target_folder, character), '.png')] )
for idx, character in enumerate(self._characters)] self._character_images = [
[(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
for idx, character in enumerate(self._characters)
]
self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, []) self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, [])
def __len__(self) -> int: def __len__(self) -> int:
...@@ -68,7 +72,7 @@ class Omniglot(VisionDataset): ...@@ -68,7 +72,7 @@ class Omniglot(VisionDataset):
""" """
image_name, character_class = self._flat_character_images[index] image_name, character_class = self._flat_character_images[index]
image_path = join(self.target_folder, self._characters[character_class], image_name) image_path = join(self.target_folder, self._characters[character_class], image_name)
image = Image.open(image_path, mode='r').convert('L') image = Image.open(image_path, mode="r").convert("L")
if self.transform: if self.transform:
image = self.transform(image) image = self.transform(image)
...@@ -80,19 +84,19 @@ class Omniglot(VisionDataset): ...@@ -80,19 +84,19 @@ class Omniglot(VisionDataset):
def _check_integrity(self) -> bool: def _check_integrity(self) -> bool:
zip_filename = self._get_target_folder() zip_filename = self._get_target_folder()
if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]): if not check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename]):
return False return False
return True return True
def download(self) -> None: def download(self) -> None:
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print("Files already downloaded and verified")
return return
filename = self._get_target_folder() filename = self._get_target_folder()
zip_filename = filename + '.zip' zip_filename = filename + ".zip"
url = self.download_url_prefix + '/' + zip_filename url = self.download_url_prefix + "/" + zip_filename
download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename]) download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])
def _get_target_folder(self) -> str: def _get_target_folder(self) -> str:
return 'images_background' if self.background else 'images_evaluation' return "images_background" if self.background else "images_evaluation"
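A short usage sketch (hypothetical ./data root); the background split is conventionally used for training and the evaluation split for testing:

from torchvision.datasets import Omniglot

background = Omniglot("./data", background=True, download=True)
evaluation = Omniglot("./data", background=False, download=True)
image, character_class = background[0]  # grayscale PIL image and an int class index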
import os import os
import numpy as np
from PIL import Image
from typing import Any, Callable, List, Optional, Tuple, Union from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch import torch
from .vision import VisionDataset from PIL import Image
from .utils import download_url from .utils import download_url
from .vision import VisionDataset
class PhotoTour(VisionDataset): class PhotoTour(VisionDataset):
...@@ -33,56 +33,67 @@ class PhotoTour(VisionDataset): ...@@ -33,56 +33,67 @@ class PhotoTour(VisionDataset):
downloaded again. downloaded again.
""" """
urls = { urls = {
'notredame_harris': [ "notredame_harris": [
'http://matthewalunbrown.com/patchdata/notredame_harris.zip', "http://matthewalunbrown.com/patchdata/notredame_harris.zip",
'notredame_harris.zip', "notredame_harris.zip",
'69f8c90f78e171349abdf0307afefe4d' "69f8c90f78e171349abdf0307afefe4d",
],
'yosemite_harris': [
'http://matthewalunbrown.com/patchdata/yosemite_harris.zip',
'yosemite_harris.zip',
'a73253d1c6fbd3ba2613c45065c00d46'
], ],
'liberty_harris': [ "yosemite_harris": [
'http://matthewalunbrown.com/patchdata/liberty_harris.zip', "http://matthewalunbrown.com/patchdata/yosemite_harris.zip",
'liberty_harris.zip', "yosemite_harris.zip",
'c731fcfb3abb4091110d0ae8c7ba182c' "a73253d1c6fbd3ba2613c45065c00d46",
], ],
'notredame': [ "liberty_harris": [
'http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip', "http://matthewalunbrown.com/patchdata/liberty_harris.zip",
'notredame.zip', "liberty_harris.zip",
'509eda8535847b8c0a90bbb210c83484' "c731fcfb3abb4091110d0ae8c7ba182c",
], ],
'yosemite': [ "notredame": [
'http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip', "http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip",
'yosemite.zip', "notredame.zip",
'533b2e8eb7ede31be40abc317b2fd4f0' "509eda8535847b8c0a90bbb210c83484",
],
'liberty': [
'http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip',
'liberty.zip',
'fdd9152f138ea5ef2091746689176414'
], ],
"yosemite": ["http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip", "yosemite.zip", "533b2e8eb7ede31be40abc317b2fd4f0"],
"liberty": ["http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip", "liberty.zip", "fdd9152f138ea5ef2091746689176414"],
}
means = {
"notredame": 0.4854,
"yosemite": 0.4844,
"liberty": 0.4437,
"notredame_harris": 0.4854,
"yosemite_harris": 0.4844,
"liberty_harris": 0.4437,
}
stds = {
"notredame": 0.1864,
"yosemite": 0.1818,
"liberty": 0.2019,
"notredame_harris": 0.1864,
"yosemite_harris": 0.1818,
"liberty_harris": 0.2019,
}
lens = {
"notredame": 468159,
"yosemite": 633587,
"liberty": 450092,
"liberty_harris": 379587,
"yosemite_harris": 450912,
"notredame_harris": 325295,
} }
means = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437, image_ext = "bmp"
'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437} info_file = "info.txt"
stds = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019, matches_files = "m50_100000_100000_0.txt"
'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019}
lens = {'notredame': 468159, 'yosemite': 633587, 'liberty': 450092,
'liberty_harris': 379587, 'yosemite_harris': 450912, 'notredame_harris': 325295}
image_ext = 'bmp'
info_file = 'info.txt'
matches_files = 'm50_100000_100000_0.txt'
def __init__( def __init__(
self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False
) -> None: ) -> None:
super(PhotoTour, self).__init__(root, transform=transform) super(PhotoTour, self).__init__(root, transform=transform)
self.name = name self.name = name
self.data_dir = os.path.join(self.root, name) self.data_dir = os.path.join(self.root, name)
self.data_down = os.path.join(self.root, '{}.zip'.format(name)) self.data_down = os.path.join(self.root, "{}.zip".format(name))
self.data_file = os.path.join(self.root, '{}.pt'.format(name)) self.data_file = os.path.join(self.root, "{}.pt".format(name))
self.train = train self.train = train
self.mean = self.means[name] self.mean = self.means[name]
...@@ -128,7 +139,7 @@ class PhotoTour(VisionDataset): ...@@ -128,7 +139,7 @@ class PhotoTour(VisionDataset):
def download(self) -> None: def download(self) -> None:
if self._check_datafile_exists(): if self._check_datafile_exists():
print('# Found cached data {}'.format(self.data_file)) print("# Found cached data {}".format(self.data_file))
return return
if not self._check_downloaded(): if not self._check_downloaded():
...@@ -140,25 +151,26 @@ class PhotoTour(VisionDataset): ...@@ -140,25 +151,26 @@ class PhotoTour(VisionDataset):
download_url(url, self.root, filename, md5) download_url(url, self.root, filename, md5)
print('# Extracting data {}\n'.format(self.data_down)) print("# Extracting data {}\n".format(self.data_down))
import zipfile import zipfile
with zipfile.ZipFile(fpath, 'r') as z:
with zipfile.ZipFile(fpath, "r") as z:
z.extractall(self.data_dir) z.extractall(self.data_dir)
os.unlink(fpath) os.unlink(fpath)
def cache(self) -> None: def cache(self) -> None:
# process and save as torch files # process and save as torch files
print('# Caching data {}'.format(self.data_file)) print("# Caching data {}".format(self.data_file))
dataset = ( dataset = (
read_image_file(self.data_dir, self.image_ext, self.lens[self.name]), read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
read_info_file(self.data_dir, self.info_file), read_info_file(self.data_dir, self.info_file),
read_matches_files(self.data_dir, self.matches_files) read_matches_files(self.data_dir, self.matches_files),
) )
with open(self.data_file, 'wb') as f: with open(self.data_file, "wb") as f:
torch.save(dataset, f) torch.save(dataset, f)
def extra_repr(self) -> str: def extra_repr(self) -> str:
...@@ -166,17 +178,14 @@ class PhotoTour(VisionDataset): ...@@ -166,17 +178,14 @@ class PhotoTour(VisionDataset):
def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor: def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
"""Return a Tensor containing the patches """Return a Tensor containing the patches"""
"""
def PIL2array(_img: Image.Image) -> np.ndarray: def PIL2array(_img: Image.Image) -> np.ndarray:
"""Convert PIL image type to numpy 2D array """Convert PIL image type to numpy 2D array"""
"""
return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64) return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)
def find_files(_data_dir: str, _image_ext: str) -> List[str]: def find_files(_data_dir: str, _image_ext: str) -> List[str]:
"""Return a list with the file names of the images containing the patches """Return a list with the file names of the images containing the patches"""
"""
files = [] files = []
# find those files with the specified extension # find those files with the specified extension
for file_dir in os.listdir(_data_dir): for file_dir in os.listdir(_data_dir):
...@@ -198,22 +207,21 @@ def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor: ...@@ -198,22 +207,21 @@ def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
def read_info_file(data_dir: str, info_file: str) -> torch.Tensor: def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
"""Return a Tensor containing the list of labels """Return a Tensor containing the list of labels
Read the file and keep only the ID of the 3D point. Read the file and keep only the ID of the 3D point.
""" """
with open(os.path.join(data_dir, info_file), 'r') as f: with open(os.path.join(data_dir, info_file), "r") as f:
labels = [int(line.split()[0]) for line in f] labels = [int(line.split()[0]) for line in f]
return torch.LongTensor(labels) return torch.LongTensor(labels)
def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor: def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
"""Return a Tensor containing the ground truth matches """Return a Tensor containing the ground truth matches
Read the file and keep only 3D point ID. Read the file and keep only 3D point ID.
Matches are represented with a 1, non matches with a 0. Matches are represented with a 1, non matches with a 0.
""" """
matches = [] matches = []
with open(os.path.join(data_dir, matches_file), 'r') as f: with open(os.path.join(data_dir, matches_file), "r") as f:
for line in f: for line in f:
line_split = line.split() line_split = line.split()
matches.append([int(line_split[0]), int(line_split[3]), matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
int(line_split[1] == line_split[4])])
return torch.LongTensor(matches) return torch.LongTensor(matches)
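To make the column convention concrete, a hedged sketch with a fabricated row in the m50 matches format (columns assumed to be patch ID, 3D point ID, unused, repeated for the second patch):

line = "17 102 0 45 102 0"  # fabricated example row
ls = line.split()
row = [int(ls[0]), int(ls[3]), int(ls[1] == ls[4])]
print(row)  # [17, 45, 1]: patches 17 and 45 observe the same 3D point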
from .clip_sampler import DistributedSampler, UniformClipSampler, RandomClipSampler from .clip_sampler import DistributedSampler, UniformClipSampler, RandomClipSampler
__all__ = ('DistributedSampler', 'UniformClipSampler', 'RandomClipSampler') __all__ = ("DistributedSampler", "UniformClipSampler", "RandomClipSampler")
import math import math
from typing import Optional, List, Iterator, Sized, Union, cast
import torch import torch
from torch.utils.data import Sampler
import torch.distributed as dist import torch.distributed as dist
from torch.utils.data import Sampler
from torchvision.datasets.video_utils import VideoClips from torchvision.datasets.video_utils import VideoClips
from typing import Optional, List, Iterator, Sized, Union, cast
class DistributedSampler(Sampler): class DistributedSampler(Sampler):
...@@ -36,12 +37,12 @@ class DistributedSampler(Sampler): ...@@ -36,12 +37,12 @@ class DistributedSampler(Sampler):
""" """
def __init__( def __init__(
self, self,
dataset: Sized, dataset: Sized,
num_replicas: Optional[int] = None, num_replicas: Optional[int] = None,
rank: Optional[int] = None, rank: Optional[int] = None,
shuffle: bool = False, shuffle: bool = False,
group_size: int = 1, group_size: int = 1,
) -> None: ) -> None:
if num_replicas is None: if num_replicas is None:
if not dist.is_available(): if not dist.is_available():
...@@ -51,9 +52,11 @@ class DistributedSampler(Sampler): ...@@ -51,9 +52,11 @@ class DistributedSampler(Sampler):
if not dist.is_available(): if not dist.is_available():
raise RuntimeError("Requires distributed package to be available") raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank() rank = dist.get_rank()
assert len(dataset) % group_size == 0, ( assert (
"dataset length must be a multiple of group size, " len(dataset) % group_size == 0
"dataset length: %d, group size: %d" % (len(dataset), group_size) ), "dataset length must be a multiple of group size, " "dataset length: %d, group size: %d" % (
len(dataset),
group_size,
) )
self.dataset = dataset self.dataset = dataset
self.group_size = group_size self.group_size = group_size
...@@ -61,9 +64,7 @@ class DistributedSampler(Sampler): ...@@ -61,9 +64,7 @@ class DistributedSampler(Sampler):
self.rank = rank self.rank = rank
self.epoch = 0 self.epoch = 0
dataset_group_length = len(dataset) // group_size dataset_group_length = len(dataset) // group_size
self.num_group_samples = int( self.num_group_samples = int(math.ceil(dataset_group_length * 1.0 / self.num_replicas))
math.ceil(dataset_group_length * 1.0 / self.num_replicas)
)
self.num_samples = self.num_group_samples * group_size self.num_samples = self.num_group_samples * group_size
self.total_size = self.num_samples * self.num_replicas self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle self.shuffle = shuffle
...@@ -79,16 +80,14 @@ class DistributedSampler(Sampler): ...@@ -79,16 +80,14 @@ class DistributedSampler(Sampler):
indices = list(range(len(self.dataset))) indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible # add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))] indices += indices[: (self.total_size - len(indices))]
assert len(indices) == self.total_size assert len(indices) == self.total_size
total_group_size = self.total_size // self.group_size total_group_size = self.total_size // self.group_size
indices = torch.reshape( indices = torch.reshape(torch.LongTensor(indices), (total_group_size, self.group_size))
torch.LongTensor(indices), (total_group_size, self.group_size)
)
# subsample # subsample
indices = indices[self.rank:total_group_size:self.num_replicas, :] indices = indices[self.rank : total_group_size : self.num_replicas, :]
indices = torch.reshape(indices, (-1,)).tolist() indices = torch.reshape(indices, (-1,)).tolist()
assert len(indices) == self.num_samples assert len(indices) == self.num_samples
...@@ -115,10 +114,10 @@ class UniformClipSampler(Sampler): ...@@ -115,10 +114,10 @@ class UniformClipSampler(Sampler):
video_clips (VideoClips): video clips to sample from video_clips (VideoClips): video clips to sample from
num_clips_per_video (int): number of clips to be sampled per video num_clips_per_video (int): number of clips to be sampled per video
""" """
def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None: def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None:
if not isinstance(video_clips, VideoClips): if not isinstance(video_clips, VideoClips):
raise TypeError("Expected video_clips to be an instance of VideoClips, " raise TypeError("Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips)))
"got {}".format(type(video_clips)))
self.video_clips = video_clips self.video_clips = video_clips
self.num_clips_per_video = num_clips_per_video self.num_clips_per_video = num_clips_per_video
...@@ -132,19 +131,13 @@ class UniformClipSampler(Sampler): ...@@ -132,19 +131,13 @@ class UniformClipSampler(Sampler):
# corner case where video decoding fails # corner case where video decoding fails
continue continue
sampled = ( sampled = torch.linspace(s, s + length - 1, steps=self.num_clips_per_video).floor().to(torch.int64)
torch.linspace(s, s + length - 1, steps=self.num_clips_per_video)
.floor()
.to(torch.int64)
)
s += length s += length
idxs.append(sampled) idxs.append(sampled)
return iter(cast(List[int], torch.cat(idxs).tolist())) return iter(cast(List[int], torch.cat(idxs).tolist()))
def __len__(self) -> int: def __len__(self) -> int:
return sum( return sum(self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0)
self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0
)
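For intuition, the linspace/floor sampling above picks evenly spaced clip indices within each video's range; a tiny worked example:

import torch

s, length, num_clips_per_video = 0, 10, 3
sampled = torch.linspace(s, s + length - 1, steps=num_clips_per_video).floor().to(torch.int64)
print(sampled)  # tensor([0, 4, 9])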
class RandomClipSampler(Sampler): class RandomClipSampler(Sampler):
...@@ -155,10 +148,10 @@ class RandomClipSampler(Sampler): ...@@ -155,10 +148,10 @@ class RandomClipSampler(Sampler):
video_clips (VideoClips): video clips to sample from video_clips (VideoClips): video clips to sample from
max_clips_per_video (int): maximum number of clips to be sampled per video max_clips_per_video (int): maximum number of clips to be sampled per video
""" """
def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None: def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None:
if not isinstance(video_clips, VideoClips): if not isinstance(video_clips, VideoClips):
raise TypeError("Expected video_clips to be an instance of VideoClips, " raise TypeError("Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips)))
"got {}".format(type(video_clips)))
self.video_clips = video_clips self.video_clips = video_clips
self.max_clips_per_video = max_clips_per_video self.max_clips_per_video = max_clips_per_video
......
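To illustrate the group-aware subsampling in DistributedSampler, a minimal sketch of the visible reshape-and-stride logic with toy numbers (no actual distributed setup):

import torch

# 12 samples in groups of 3 (e.g. 3 clips per video), split across 2 replicas
groups = torch.arange(12).reshape(-1, 3)  # 4 groups of size 3
rank, num_replicas = 0, 2
mine = groups[rank::num_replicas].reshape(-1).tolist()
print(mine)  # [0, 1, 2, 6, 7, 8]; whole groups stay on the same replica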
import os import os
import shutil import shutil
from .vision import VisionDataset
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from .utils import download_url, verify_str_arg, download_and_extract_archive from .utils import download_url, verify_str_arg, download_and_extract_archive
from .vision import VisionDataset
class SBDataset(VisionDataset): class SBDataset(VisionDataset):
...@@ -50,30 +50,29 @@ class SBDataset(VisionDataset): ...@@ -50,30 +50,29 @@ class SBDataset(VisionDataset):
voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722" voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722"
def __init__( def __init__(
self, self,
root: str, root: str,
image_set: str = "train", image_set: str = "train",
mode: str = "boundaries", mode: str = "boundaries",
download: bool = False, download: bool = False,
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
) -> None: ) -> None:
try: try:
from scipy.io import loadmat from scipy.io import loadmat
self._loadmat = loadmat self._loadmat = loadmat
except ImportError: except ImportError:
raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: " raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: " "pip install scipy")
"pip install scipy")
super(SBDataset, self).__init__(root, transforms) super(SBDataset, self).__init__(root, transforms)
self.image_set = verify_str_arg(image_set, "image_set", self.image_set = verify_str_arg(image_set, "image_set", ("train", "val", "train_noval"))
("train", "val", "train_noval"))
self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries")) self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
self.num_classes = 20 self.num_classes = 20
sbd_root = self.root sbd_root = self.root
image_dir = os.path.join(sbd_root, 'img') image_dir = os.path.join(sbd_root, "img")
mask_dir = os.path.join(sbd_root, 'cls') mask_dir = os.path.join(sbd_root, "cls")
if download: if download:
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5) download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
...@@ -81,36 +80,35 @@ class SBDataset(VisionDataset): ...@@ -81,36 +80,35 @@ class SBDataset(VisionDataset):
for f in ["cls", "img", "inst", "train.txt", "val.txt"]: for f in ["cls", "img", "inst", "train.txt", "val.txt"]:
old_path = os.path.join(extracted_ds_root, f) old_path = os.path.join(extracted_ds_root, f)
shutil.move(old_path, sbd_root) shutil.move(old_path, sbd_root)
download_url(self.voc_train_url, sbd_root, self.voc_split_filename, download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5)
self.voc_split_md5)
if not os.path.isdir(sbd_root): if not os.path.isdir(sbd_root):
raise RuntimeError('Dataset not found or corrupted.' + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
' You can use download=True to download it')
split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt') split_f = os.path.join(sbd_root, image_set.rstrip("\n") + ".txt")
with open(os.path.join(split_f), "r") as fh: with open(os.path.join(split_f), "r") as fh:
file_names = [x.strip() for x in fh.readlines()] file_names = [x.strip() for x in fh.readlines()]
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names] self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
assert (len(self.images) == len(self.masks)) assert len(self.images) == len(self.masks)
self._get_target = self._get_segmentation_target \ self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target
if self.mode == "segmentation" else self._get_boundaries_target
def _get_segmentation_target(self, filepath: str) -> Image.Image: def _get_segmentation_target(self, filepath: str) -> Image.Image:
mat = self._loadmat(filepath) mat = self._loadmat(filepath)
return Image.fromarray(mat['GTcls'][0]['Segmentation'][0]) return Image.fromarray(mat["GTcls"][0]["Segmentation"][0])
def _get_boundaries_target(self, filepath: str) -> np.ndarray: def _get_boundaries_target(self, filepath: str) -> np.ndarray:
mat = self._loadmat(filepath) mat = self._loadmat(filepath)
return np.concatenate([np.expand_dims(mat['GTcls'][0]['Boundaries'][0][i][0].toarray(), axis=0) return np.concatenate(
for i in range(self.num_classes)], axis=0) [np.expand_dims(mat["GTcls"][0]["Boundaries"][0][i][0].toarray(), axis=0) for i in range(self.num_classes)],
axis=0,
)
def __getitem__(self, index: int) -> Tuple[Any, Any]: def __getitem__(self, index: int) -> Tuple[Any, Any]:
img = Image.open(self.images[index]).convert('RGB') img = Image.open(self.images[index]).convert("RGB")
target = self._get_target(self.masks[index]) target = self._get_target(self.masks[index])
if self.transforms is not None: if self.transforms is not None:
...@@ -123,4 +121,4 @@ class SBDataset(VisionDataset): ...@@ -123,4 +121,4 @@ class SBDataset(VisionDataset):
def extra_repr(self) -> str: def extra_repr(self) -> str:
lines = ["Image set: {image_set}", "Mode: {mode}"] lines = ["Image set: {image_set}", "Mode: {mode}"]
return '\n'.join(lines).format(**self.__dict__) return "\n".join(lines).format(**self.__dict__)
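A usage sketch for the two modes (hypothetical ./data/sbd root; requires scipy): segmentation mode yields a PIL mask, while boundaries mode yields a (num_classes, H, W) numpy array with one binary channel per class:

from torchvision.datasets import SBDataset

ds = SBDataset("./data/sbd", image_set="train", mode="boundaries", download=True)
img, target = ds[0]  # PIL RGB image and a (20, H, W) boundary array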
from PIL import Image import os
from .utils import download_url, check_integrity
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple
import os from PIL import Image
from .utils import download_url, check_integrity
from .vision import VisionDataset from .vision import VisionDataset
...@@ -20,38 +21,37 @@ class SBU(VisionDataset): ...@@ -20,38 +21,37 @@ class SBU(VisionDataset):
puts it in root directory. If dataset is already downloaded, it is not puts it in root directory. If dataset is already downloaded, it is not
downloaded again. downloaded again.
""" """
url = "http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz" url = "http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
filename = "SBUCaptionedPhotoDataset.tar.gz" filename = "SBUCaptionedPhotoDataset.tar.gz"
md5_checksum = '9aec147b3488753cf758b4d493422285' md5_checksum = "9aec147b3488753cf758b4d493422285"
def __init__( def __init__(
self, self,
root: str, root: str,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = True, download: bool = True,
) -> None: ) -> None:
super(SBU, self).__init__(root, transform=transform, super(SBU, self).__init__(root, transform=transform, target_transform=target_transform)
target_transform=target_transform)
if download: if download:
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
' You can use download=True to download it')
# Read the caption for each photo # Read the caption for each photo
self.photos = [] self.photos = []
self.captions = [] self.captions = []
file1 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt') file1 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")
file2 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_captions.txt') file2 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_captions.txt")
for line1, line2 in zip(open(file1), open(file2)): for line1, line2 in zip(open(file1), open(file2)):
url = line1.rstrip() url = line1.rstrip()
photo = os.path.basename(url) photo = os.path.basename(url)
filename = os.path.join(self.root, 'dataset', photo) filename = os.path.join(self.root, "dataset", photo)
if os.path.exists(filename): if os.path.exists(filename):
caption = line2.rstrip() caption = line2.rstrip()
self.photos.append(photo) self.photos.append(photo)
...@@ -65,8 +65,8 @@ class SBU(VisionDataset): ...@@ -65,8 +65,8 @@ class SBU(VisionDataset):
Returns: Returns:
tuple: (image, target) where target is a caption for the photo. tuple: (image, target) where target is a caption for the photo.
""" """
filename = os.path.join(self.root, 'dataset', self.photos[index]) filename = os.path.join(self.root, "dataset", self.photos[index])
img = Image.open(filename).convert('RGB') img = Image.open(filename).convert("RGB")
if self.transform is not None: if self.transform is not None:
img = self.transform(img) img = self.transform(img)
...@@ -93,21 +93,21 @@ class SBU(VisionDataset): ...@@ -93,21 +93,21 @@ class SBU(VisionDataset):
import tarfile import tarfile
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print("Files already downloaded and verified")
return return
download_url(self.url, self.root, self.filename, self.md5_checksum) download_url(self.url, self.root, self.filename, self.md5_checksum)
# Extract file # Extract file
with tarfile.open(os.path.join(self.root, self.filename), 'r:gz') as tar: with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
tar.extractall(path=self.root) tar.extractall(path=self.root)
# Download individual photos # Download individual photos
with open(os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt')) as fh: with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh:
for line in fh: for line in fh:
url = line.rstrip() url = line.rstrip()
try: try:
download_url(url, os.path.join(self.root, 'dataset')) download_url(url, os.path.join(self.root, "dataset"))
except OSError: except OSError:
# The images point to public images on Flickr. # The images point to public images on Flickr.
# Note: Images might be removed by users at anytime. # Note: Images might be removed by users at anytime.
......
from PIL import Image
import os import os
import os.path import os.path
import numpy as np
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple
from .vision import VisionDataset
import numpy as np
from PIL import Image
from .utils import download_url, check_integrity from .utils import download_url, check_integrity
from .vision import VisionDataset
class SEMEION(VisionDataset): class SEMEION(VisionDataset):
...@@ -24,30 +26,28 @@ class SEMEION(VisionDataset): ...@@ -24,30 +26,28 @@ class SEMEION(VisionDataset):
""" """
url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data" url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data"
filename = "semeion.data" filename = "semeion.data"
md5_checksum = 'cb545d371d2ce14ec121470795a77432' md5_checksum = "cb545d371d2ce14ec121470795a77432"
def __init__( def __init__(
self, self,
root: str, root: str,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = True, download: bool = True,
) -> None: ) -> None:
super(SEMEION, self).__init__(root, transform=transform, super(SEMEION, self).__init__(root, transform=transform, target_transform=target_transform)
target_transform=target_transform)
if download: if download:
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
' You can use download=True to download it')
fp = os.path.join(self.root, self.filename) fp = os.path.join(self.root, self.filename)
data = np.loadtxt(fp) data = np.loadtxt(fp)
# convert value to 8 bit unsigned integer # convert value to 8 bit unsigned integer
# color (white #255) the pixels # color (white #255) the pixels
self.data = (data[:, :256] * 255).astype('uint8') self.data = (data[:, :256] * 255).astype("uint8")
self.data = np.reshape(self.data, (-1, 16, 16)) self.data = np.reshape(self.data, (-1, 16, 16))
self.labels = np.nonzero(data[:, 256:])[1] self.labels = np.nonzero(data[:, 256:])[1]
...@@ -63,7 +63,7 @@ class SEMEION(VisionDataset): ...@@ -63,7 +63,7 @@ class SEMEION(VisionDataset):
# doing this so that it is consistent with all other datasets # doing this so that it is consistent with all other datasets
# to return a PIL Image # to return a PIL Image
img = Image.fromarray(img, mode='L') img = Image.fromarray(img, mode="L")
if self.transform is not None: if self.transform is not None:
img = self.transform(img) img = self.transform(img)
...@@ -85,7 +85,7 @@ class SEMEION(VisionDataset): ...@@ -85,7 +85,7 @@ class SEMEION(VisionDataset):
def download(self) -> None: def download(self) -> None:
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print("Files already downloaded and verified")
return return
root = self.root root = self.root
......
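The label columns of semeion.data are one-hot encoded; np.nonzero(...)[1] above decodes them into class indices. A toy illustration:

import numpy as np

onehot = np.array([[0, 0, 1, 0], [1, 0, 0, 0]])  # rows encode classes 2 and 0
labels = np.nonzero(onehot)[1]
print(labels)  # [2 0]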
from PIL import Image
import os import os
import os.path import os.path
import numpy as np
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple
from .vision import VisionDataset import numpy as np
from PIL import Image
from .utils import check_integrity, download_and_extract_archive, verify_str_arg from .utils import check_integrity, download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
class STL10(VisionDataset): class STL10(VisionDataset):
...@@ -27,70 +28,60 @@ class STL10(VisionDataset): ...@@ -27,70 +28,60 @@ class STL10(VisionDataset):
puts it in root directory. If dataset is already downloaded, it is not puts it in root directory. If dataset is already downloaded, it is not
downloaded again. downloaded again.
""" """
base_folder = 'stl10_binary'
base_folder = "stl10_binary"
url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz" url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
filename = "stl10_binary.tar.gz" filename = "stl10_binary.tar.gz"
tgz_md5 = '91f7769df0f17e558f3565bffb0c7dfb' tgz_md5 = "91f7769df0f17e558f3565bffb0c7dfb"
class_names_file = 'class_names.txt' class_names_file = "class_names.txt"
folds_list_file = 'fold_indices.txt' folds_list_file = "fold_indices.txt"
train_list = [ train_list = [
['train_X.bin', '918c2871b30a85fa023e0c44e0bee87f'], ["train_X.bin", "918c2871b30a85fa023e0c44e0bee87f"],
['train_y.bin', '5a34089d4802c674881badbb80307741'], ["train_y.bin", "5a34089d4802c674881badbb80307741"],
['unlabeled_X.bin', '5242ba1fed5e4be9e1e742405eb56ca4'] ["unlabeled_X.bin", "5242ba1fed5e4be9e1e742405eb56ca4"],
] ]
test_list = [ test_list = [["test_X.bin", "7f263ba9f9e0b06b93213547f721ac82"], ["test_y.bin", "36f9794fa4beb8a2c72628de14fa638e"]]
['test_X.bin', '7f263ba9f9e0b06b93213547f721ac82'], splits = ("train", "train+unlabeled", "unlabeled", "test")
['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e']
]
splits = ('train', 'train+unlabeled', 'unlabeled', 'test')
def __init__( def __init__(
self, self,
root: str, root: str,
split: str = "train", split: str = "train",
folds: Optional[int] = None, folds: Optional[int] = None,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(STL10, self).__init__(root, transform=transform, super(STL10, self).__init__(root, transform=transform, target_transform=target_transform)
target_transform=target_transform)
self.split = verify_str_arg(split, "split", self.splits) self.split = verify_str_arg(split, "split", self.splits)
self.folds = self._verify_folds(folds) self.folds = self._verify_folds(folds)
if download: if download:
self.download() self.download()
elif not self._check_integrity(): elif not self._check_integrity():
raise RuntimeError( raise RuntimeError("Dataset not found or corrupted. " "You can use download=True to download it")
'Dataset not found or corrupted. '
'You can use download=True to download it')
# now load the picked numpy arrays # now load the picked numpy arrays
self.labels: Optional[np.ndarray] self.labels: Optional[np.ndarray]
if self.split == 'train': if self.split == "train":
self.data, self.labels = self.__loadfile( self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
self.train_list[0][0], self.train_list[1][0])
self.__load_folds(folds) self.__load_folds(folds)
elif self.split == 'train+unlabeled': elif self.split == "train+unlabeled":
self.data, self.labels = self.__loadfile( self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
self.train_list[0][0], self.train_list[1][0])
self.__load_folds(folds) self.__load_folds(folds)
unlabeled_data, _ = self.__loadfile(self.train_list[2][0]) unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
self.data = np.concatenate((self.data, unlabeled_data)) self.data = np.concatenate((self.data, unlabeled_data))
self.labels = np.concatenate( self.labels = np.concatenate((self.labels, np.asarray([-1] * unlabeled_data.shape[0])))
(self.labels, np.asarray([-1] * unlabeled_data.shape[0])))
elif self.split == 'unlabeled': elif self.split == "unlabeled":
self.data, _ = self.__loadfile(self.train_list[2][0]) self.data, _ = self.__loadfile(self.train_list[2][0])
self.labels = np.asarray([-1] * self.data.shape[0]) self.labels = np.asarray([-1] * self.data.shape[0])
else: # self.split == 'test': else: # self.split == 'test':
self.data, self.labels = self.__loadfile( self.data, self.labels = self.__loadfile(self.test_list[0][0], self.test_list[1][0])
self.test_list[0][0], self.test_list[1][0])
class_file = os.path.join( class_file = os.path.join(self.root, self.base_folder, self.class_names_file)
self.root, self.base_folder, self.class_names_file)
if os.path.isfile(class_file): if os.path.isfile(class_file):
with open(class_file) as f: with open(class_file) as f:
self.classes = f.read().splitlines() self.classes = f.read().splitlines()
...@@ -101,8 +92,7 @@ class STL10(VisionDataset): ...@@ -101,8 +92,7 @@ class STL10(VisionDataset):
elif isinstance(folds, int): elif isinstance(folds, int):
if folds in range(10): if folds in range(10):
return folds return folds
msg = ("Value for argument folds should be in the range [0, 10), " msg = "Value for argument folds should be in the range [0, 10), " "but got {}."
"but got {}.")
raise ValueError(msg.format(folds)) raise ValueError(msg.format(folds))
else: else:
msg = "Expected type None or int for argument folds, but got type {}." msg = "Expected type None or int for argument folds, but got type {}."
...@@ -140,13 +130,12 @@ class STL10(VisionDataset): ...@@ -140,13 +130,12 @@ class STL10(VisionDataset):
def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]: def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
labels = None labels = None
if labels_file: if labels_file:
path_to_labels = os.path.join( path_to_labels = os.path.join(self.root, self.base_folder, labels_file)
self.root, self.base_folder, labels_file) with open(path_to_labels, "rb") as f:
with open(path_to_labels, 'rb') as f:
labels = np.fromfile(f, dtype=np.uint8) - 1 # 0-based labels = np.fromfile(f, dtype=np.uint8) - 1 # 0-based
path_to_data = os.path.join(self.root, self.base_folder, data_file) path_to_data = os.path.join(self.root, self.base_folder, data_file)
with open(path_to_data, 'rb') as f: with open(path_to_data, "rb") as f:
# read whole file in uint8 chunks # read whole file in uint8 chunks
everything = np.fromfile(f, dtype=np.uint8) everything = np.fromfile(f, dtype=np.uint8)
images = np.reshape(everything, (-1, 3, 96, 96)) images = np.reshape(everything, (-1, 3, 96, 96))
...@@ -156,7 +145,7 @@ class STL10(VisionDataset): ...@@ -156,7 +145,7 @@ class STL10(VisionDataset):
def _check_integrity(self) -> bool: def _check_integrity(self) -> bool:
root = self.root root = self.root
for fentry in (self.train_list + self.test_list): for fentry in self.train_list + self.test_list:
filename, md5 = fentry[0], fentry[1] filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename) fpath = os.path.join(root, self.base_folder, filename)
if not check_integrity(fpath, md5): if not check_integrity(fpath, md5):
...@@ -165,7 +154,7 @@ class STL10(VisionDataset): ...@@ -165,7 +154,7 @@ class STL10(VisionDataset):
def download(self) -> None: def download(self) -> None:
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print("Files already downloaded and verified")
return return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
self._check_integrity() self._check_integrity()
...@@ -177,11 +166,10 @@ class STL10(VisionDataset): ...@@ -177,11 +166,10 @@ class STL10(VisionDataset):
# loads one of the folds if specified # loads one of the folds if specified
if folds is None: if folds is None:
return return
path_to_folds = os.path.join( path_to_folds = os.path.join(self.root, self.base_folder, self.folds_list_file)
self.root, self.base_folder, self.folds_list_file) with open(path_to_folds, "r") as f:
with open(path_to_folds, 'r') as f:
str_idx = f.read().splitlines()[folds] str_idx = f.read().splitlines()[folds]
list_idx = np.fromstring(str_idx, dtype=np.int64, sep=' ') list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ")
self.data = self.data[list_idx, :, :, :] self.data = self.data[list_idx, :, :, :]
if self.labels is not None: if self.labels is not None:
self.labels = self.labels[list_idx] self.labels = self.labels[list_idx]
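A brief sketch of the split and folds arguments (hypothetical ./data root); folds must lie in [0, 10) and selects one predefined training fold:

from torchvision.datasets import STL10

fold0 = STL10("./data", split="train", folds=0, download=True)  # one predefined fold
unlabeled = STL10("./data", split="unlabeled", download=True)  # labels are all -1
mixed = STL10("./data", split="train+unlabeled", download=True)  # train set plus unlabeled (-1) samples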
from .vision import VisionDataset
from PIL import Image
import os import os
import os.path import os.path
import numpy as np
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple
import numpy as np
from PIL import Image
from .utils import download_url, check_integrity, verify_str_arg from .utils import download_url, check_integrity, verify_str_arg
from .vision import VisionDataset
class SVHN(VisionDataset): class SVHN(VisionDataset):
...@@ -33,23 +35,32 @@ class SVHN(VisionDataset): ...@@ -33,23 +35,32 @@ class SVHN(VisionDataset):
""" """
split_list = { split_list = {
'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", "train": [
"train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], "http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", "train_32x32.mat",
"test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], "e26dedcc434d2e4c54c9b2d4a06d8373",
'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", ],
"extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]} "test": [
"http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
"test_32x32.mat",
"eb5a983be6a315427106f1b164d9cef3",
],
"extra": [
"http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
"extra_32x32.mat",
"a93ce644f1a588dc4d68dda5feec44a7",
],
}
def __init__( def __init__(
self, self,
root: str, root: str,
split: str = "train", split: str = "train",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(SVHN, self).__init__(root, transform=transform, super(SVHN, self).__init__(root, transform=transform, target_transform=target_transform)
target_transform=target_transform)
self.split = verify_str_arg(split, "split", tuple(self.split_list.keys())) self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
self.url = self.split_list[split][0] self.url = self.split_list[split][0]
self.filename = self.split_list[split][1] self.filename = self.split_list[split][1]
...@@ -59,8 +70,7 @@ class SVHN(VisionDataset): ...@@ -59,8 +70,7 @@ class SVHN(VisionDataset):
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
' You can use download=True to download it')
# import here rather than at top of file because this is # import here rather than at top of file because this is
# an optional dependency for torchvision # an optional dependency for torchvision
...@@ -69,12 +79,12 @@ class SVHN(VisionDataset): ...@@ -69,12 +79,12 @@ class SVHN(VisionDataset):
# reading(loading) mat file as array # reading(loading) mat file as array
loaded_mat = sio.loadmat(os.path.join(self.root, self.filename)) loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
self.data = loaded_mat['X'] self.data = loaded_mat["X"]
# loading from the .mat file gives an np array of type np.uint8 # loading from the .mat file gives an np array of type np.uint8
# converting to np.int64, so that we have a LongTensor after # converting to np.int64, so that we have a LongTensor after
# the conversion from the numpy array # the conversion from the numpy array
# the squeeze is needed to obtain a 1D tensor # the squeeze is needed to obtain a 1D tensor
self.labels = loaded_mat['y'].astype(np.int64).squeeze() self.labels = loaded_mat["y"].astype(np.int64).squeeze()
# the svhn dataset assigns the class label "10" to the digit 0 # the svhn dataset assigns the class label "10" to the digit 0
# this makes it inconsistent with several loss functions # this makes it inconsistent with several loss functions
......
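The remapping step itself falls in an elided hunk; a hedged sketch of the idea (not necessarily the exact line used in the source):

import numpy as np

labels = np.array([10, 1, 9, 10])  # raw SVHN labels: "10" stands for the digit 0
np.place(labels, labels == 10, 0)  # remap in place so classes run 0..9
print(labels)  # [0 1 9 0]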
import os import os
from typing import Any, Dict, List, Tuple, Optional, Callable from typing import Any, Dict, List, Tuple, Optional, Callable
from torch import Tensor from torch import Tensor
from .folder import find_classes, make_dataset from .folder import find_classes, make_dataset
...@@ -62,13 +63,13 @@ class UCF101(VisionDataset): ...@@ -62,13 +63,13 @@ class UCF101(VisionDataset):
_video_width: int = 0, _video_width: int = 0,
_video_height: int = 0, _video_height: int = 0,
_video_min_dimension: int = 0, _video_min_dimension: int = 0,
_audio_samples: int = 0 _audio_samples: int = 0,
) -> None: ) -> None:
super(UCF101, self).__init__(root) super(UCF101, self).__init__(root)
if not 1 <= fold <= 3: if not 1 <= fold <= 3:
raise ValueError("fold should be between 1 and 3, got {}".format(fold)) raise ValueError("fold should be between 1 and 3, got {}".format(fold))
extensions = ('avi',) extensions = ("avi",)
self.fold = fold self.fold = fold
self.train = train self.train = train
......
from PIL import Image
import os import os
import numpy as np
from typing import Any, Callable, cast, Optional, Tuple from typing import Any, Callable, cast, Optional, Tuple
import numpy as np
from PIL import Image
from .utils import download_url from .utils import download_url
from .vision import VisionDataset from .vision import VisionDataset
...@@ -26,28 +27,30 @@ class USPS(VisionDataset): ...@@ -26,28 +27,30 @@ class USPS(VisionDataset):
downloaded again. downloaded again.
""" """
split_list = { split_list = {
'train': [ "train": [
"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2", "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
"usps.bz2", 'ec16c51db3855ca6c91edd34d0e9b197' "usps.bz2",
"ec16c51db3855ca6c91edd34d0e9b197",
], ],
'test': [ "test": [
"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2", "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
"usps.t.bz2", '8ea070ee2aca1ac39742fdd1ef5ed118' "usps.t.bz2",
"8ea070ee2aca1ac39742fdd1ef5ed118",
], ],
} }
def __init__( def __init__(
self, self,
root: str, root: str,
train: bool = True, train: bool = True,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(USPS, self).__init__(root, transform=transform, super(USPS, self).__init__(root, transform=transform, target_transform=target_transform)
target_transform=target_transform) split = "train" if train else "test"
split = 'train' if train else 'test'
url, filename, checksum = self.split_list[split] url, filename, checksum = self.split_list[split]
full_path = os.path.join(self.root, filename) full_path = os.path.join(self.root, filename)
...@@ -55,9 +58,10 @@ class USPS(VisionDataset): ...@@ -55,9 +58,10 @@ class USPS(VisionDataset):
download_url(url, self.root, filename, md5=checksum) download_url(url, self.root, filename, md5=checksum)
import bz2 import bz2
with bz2.open(full_path) as fp: with bz2.open(full_path) as fp:
raw_data = [line.decode().split() for line in fp.readlines()] raw_data = [line.decode().split() for line in fp.readlines()]
tmp_list = [[x.split(':')[-1] for x in data[1:]] for data in raw_data] tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data]
imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16)) imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8) imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
targets = [int(d[0]) - 1 for d in raw_data] targets = [int(d[0]) - 1 for d in raw_data]
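The raw USPS files store one libsvm-style record per line with pixel values in [-1, 1]; a fabricated two-feature record shows the decode-and-rescale arithmetic (real records carry 256 features per 16x16 image):

line = "7 1:-1.0 2:1.0"  # fabricated: label 7, two features
data = line.split()
values = [float(x.split(":")[-1]) for x in data[1:]]  # [-1.0, 1.0]
pixels = [int((v + 1) / 2 * 255) for v in values]  # [0, 255]
target = int(data[0]) - 1  # labels shifted to 0-based, as above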
...@@ -77,7 +81,7 @@ class USPS(VisionDataset): ...@@ -77,7 +81,7 @@ class USPS(VisionDataset):
# doing this so that it is consistent with all other datasets # doing this so that it is consistent with all other datasets
# to return a PIL Image # to return a PIL Image
img = Image.fromarray(img, mode='L') img = Image.fromarray(img, mode="L")
if self.transform is not None: if self.transform is not None:
img = self.transform(img) img = self.transform(img)
......