Unverified Commit 5f0edb97 authored by Philip Meier, committed by GitHub

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* fix a double import introduced by the merge

Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent e45489b1
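For context, this change wires ufmt into the repository's pre-commit setup with pinned black and usort versions (see the bullets above) and reformats the dataset and sampler modules shown in the diff below. As a minimal sketch of what the formatter produces, here is the reformatted module header of the first file in the diff (the DatasetFolder/ImageFolder module); the import grouping comes from usort, the quoting and wrapping from black:

    # After ufmt: usort groups imports (standard library, then third-party, then local),
    # while black normalizes string quotes and line wrapping.
    import os
    import os.path
    from typing import Any, Callable, cast, Dict, List, Optional, Tuple

    from PIL import Image

    from .vision import VisionDataset

The exact usort category configuration and the pinned tool versions are part of this PR's config changes, which are not included in this excerpt.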
from .vision import VisionDataset
from PIL import Image
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from PIL import Image
from .vision import VisionDataset
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
"""Checks if a file is an allowed extension.
@@ -132,16 +132,15 @@ class DatasetFolder(VisionDataset):
"""
def __init__(
self,
root: str,
loader: Callable[[str], Any],
extensions: Optional[Tuple[str, ...]] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
self,
root: str,
loader: Callable[[str], Any],
extensions: Optional[Tuple[str, ...]] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> None:
super(DatasetFolder, self).__init__(root, transform=transform,
target_transform=target_transform)
super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform)
classes, class_to_idx = self.find_classes(self.root)
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
@@ -186,9 +185,7 @@ class DatasetFolder(VisionDataset):
# prevent potential bug since make_dataset() would use the class_to_idx logic of the
# find_classes() function, instead of using that of the find_classes() method, which
# is potentially overridden and thus could have a different logic.
raise ValueError(
"The class_to_idx parameter cannot be None."
)
raise ValueError("The class_to_idx parameter cannot be None.")
return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
@@ -241,19 +238,20 @@ class DatasetFolder(VisionDataset):
return len(self.samples)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
with open(path, "rb") as f:
img = Image.open(f)
return img.convert('RGB')
return img.convert("RGB")
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
import accimage
try:
return accimage.Image(path)
except IOError:
@@ -263,7 +261,8 @@ def accimage_loader(path: str) -> Any:
def default_loader(path: str) -> Any:
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
if get_image_backend() == "accimage":
return accimage_loader(path)
else:
return pil_loader(path)
@@ -300,15 +299,19 @@ class ImageFolder(DatasetFolder):
"""
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform,
is_valid_file=is_valid_file)
super(ImageFolder, self).__init__(
root,
loader,
IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform,
is_valid_file=is_valid_file,
)
self.imgs = self.samples
import glob
import os
from typing import Optional, Callable, Tuple, Dict, Any, List
from torch import Tensor
from .folder import find_classes, make_dataset
@@ -49,7 +50,7 @@ class HMDB51(VisionDataset):
data_url = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"
splits = {
"url": "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
"md5": "15e67781e70dcfbdce2d7dbb9b3344b5"
"md5": "15e67781e70dcfbdce2d7dbb9b3344b5",
}
TRAIN_TAG = 1
TEST_TAG = 2
@@ -75,7 +76,7 @@ class HMDB51(VisionDataset):
if fold not in (1, 2, 3):
raise ValueError("fold should be between 1 and 3, got {}".format(fold))
extensions = ('avi',)
extensions = ("avi",)
self.classes, class_to_idx = find_classes(self.root)
self.samples = make_dataset(
self.root,
import warnings
from contextlib import contextmanager
import os
import shutil
import tempfile
import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Iterator, Optional, Tuple
import torch
from .folder import ImageFolder
from .utils import check_integrity, extract_archive, verify_str_arg
ARCHIVE_META = {
'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'),
'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'),
'devkit': ('ILSVRC2012_devkit_t12.tar.gz', 'fa75699e90414af021442c21a62c3abf')
"train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
"val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
"devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
}
META_FILE = "meta.bin"
@@ -38,15 +40,16 @@ class ImageNet(ImageFolder):
targets (list): The class_index value for each image in the dataset
"""
def __init__(self, root: str, split: str = 'train', download: Optional[str] = None, **kwargs: Any) -> None:
def __init__(self, root: str, split: str = "train", download: Optional[str] = None, **kwargs: Any) -> None:
if download is True:
msg = ("The dataset is no longer publicly accessible. You need to "
"download the archives externally and place them in the root "
"directory.")
msg = (
"The dataset is no longer publicly accessible. You need to "
"download the archives externally and place them in the root "
"directory."
)
raise RuntimeError(msg)
elif download is False:
msg = ("The use of the download flag is deprecated, since the dataset "
"is no longer publicly accessible.")
msg = "The use of the download flag is deprecated, since the dataset " "is no longer publicly accessible."
warnings.warn(msg, RuntimeWarning)
root = self.root = os.path.expanduser(root)
@@ -61,18 +64,16 @@ class ImageNet(ImageFolder):
self.wnids = self.classes
self.wnid_to_idx = self.class_to_idx
self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
self.class_to_idx = {cls: idx
for idx, clss in enumerate(self.classes)
for cls in clss}
self.class_to_idx = {cls: idx for idx, clss in enumerate(self.classes) for cls in clss}
def parse_archives(self) -> None:
if not check_integrity(os.path.join(self.root, META_FILE)):
parse_devkit_archive(self.root)
if not os.path.isdir(self.split_folder):
if self.split == 'train':
if self.split == "train":
parse_train_archive(self.root)
elif self.split == 'val':
elif self.split == "val":
parse_val_archive(self.root)
@property
@@ -91,15 +92,19 @@ def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str
if check_integrity(file):
return torch.load(file)
else:
msg = ("The meta file {} is not present in the root directory or is corrupted. "
"This file is automatically created by the ImageNet dataset.")
msg = (
"The meta file {} is not present in the root directory or is corrupted. "
"This file is automatically created by the ImageNet dataset."
)
raise RuntimeError(msg.format(file, root))
def _verify_archive(root: str, file: str, md5: str) -> None:
if not check_integrity(os.path.join(root, file), md5):
msg = ("The archive {} is not present in the root directory or is corrupted. "
"You need to download it externally and place it in {}.")
msg = (
"The archive {} is not present in the root directory or is corrupted. "
"You need to download it externally and place it in {}."
)
raise RuntimeError(msg.format(file, root))
@@ -116,20 +121,18 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, str]]:
metafile = os.path.join(devkit_root, "data", "meta.mat")
meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
meta = sio.loadmat(metafile, squeeze_me=True)["synsets"]
nums_children = list(zip(*meta))[4]
meta = [meta[idx] for idx, num_children in enumerate(nums_children)
if num_children == 0]
meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
idcs, wnids, classes = list(zip(*meta))[:3]
classes = [tuple(clss.split(', ')) for clss in classes]
classes = [tuple(clss.split(", ")) for clss in classes]
idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
return idx_to_wnid, wnid_to_classes
def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
file = os.path.join(devkit_root, "data",
"ILSVRC2012_validation_ground_truth.txt")
with open(file, 'r') as txtfh:
file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
with open(file, "r") as txtfh:
val_idcs = txtfh.readlines()
return [int(val_idx) for val_idx in val_idcs]
from PIL import Image
import os
import os.path
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from .vision import VisionDataset
from PIL import Image
from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]
DATASET_URLS = {
'2017': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz',
'2018': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz',
'2019': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz',
'2021_train': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz',
'2021_train_mini': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz',
'2021_valid': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz',
"2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz",
"2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz",
"2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz",
"2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
"2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz",
"2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
}
DATASET_MD5 = {
'2017': '7c784ea5e424efaec655bd392f87301f',
'2018': 'b1c6952ce38f31868cc50ea72d066cc3',
'2019': 'c60a6e2962c9b8ccbd458d12c8582644',
'2021_train': '38a7bb733f7a09214d44293460ec0021',
'2021_train_mini': 'db6ed8330e634445efc8fec83ae81442',
'2021_valid': 'f6f6e0e242e3d4c9569ba56400938afc',
"2017": "7c784ea5e424efaec655bd392f87301f",
"2018": "b1c6952ce38f31868cc50ea72d066cc3",
"2019": "c60a6e2962c9b8ccbd458d12c8582644",
"2021_train": "38a7bb733f7a09214d44293460ec0021",
"2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
"2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
}
@@ -63,27 +64,26 @@ class INaturalist(VisionDataset):
"""
def __init__(
self,
root: str,
version: str = "2021_train",
target_type: Union[List[str], str] = "full",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
self,
root: str,
version: str = "2021_train",
target_type: Union[List[str], str] = "full",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
self.version = verify_str_arg(version, "version", DATASET_URLS.keys())
super(INaturalist, self).__init__(os.path.join(root, version),
transform=transform,
target_transform=target_transform)
super(INaturalist, self).__init__(
os.path.join(root, version), transform=transform, target_transform=target_transform
)
os.makedirs(root, exist_ok=True)
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
self.all_categories: List[str] = []
@@ -96,12 +96,10 @@ class INaturalist(VisionDataset):
if not isinstance(target_type, list):
target_type = [target_type]
if self.version[:4] == "2021":
self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021))
for t in target_type]
self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
self._init_2021()
else:
self.target_type = [verify_str_arg(t, "target_type", ("full", "super"))
for t in target_type]
self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
self._init_pre2021()
# index of all files: (full category id, filename)
@@ -118,16 +116,14 @@ class INaturalist(VisionDataset):
self.all_categories = sorted(os.listdir(self.root))
# map: category type -> name of category -> index
self.categories_index = {
k: {} for k in CATEGORIES_2021
}
self.categories_index = {k: {} for k in CATEGORIES_2021}
for dir_index, dir_name in enumerate(self.all_categories):
pieces = dir_name.split('_')
pieces = dir_name.split("_")
if len(pieces) != 8:
raise RuntimeError(f'Unexpected category name {dir_name}, wrong number of pieces')
if pieces[0] != f'{dir_index:05d}':
raise RuntimeError(f'Unexpected category id {pieces[0]}, expecting {dir_index:05d}')
raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
if pieces[0] != f"{dir_index:05d}":
raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
cat_map = {}
for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
if name in self.categories_index[cat]:
@@ -142,7 +138,7 @@ class INaturalist(VisionDataset):
"""Initialize based on 2017-2019 layout"""
# map: category type -> name of category -> index
self.categories_index = {'super': {}}
self.categories_index = {"super": {}}
cat_index = 0
super_categories = sorted(os.listdir(self.root))
@@ -165,7 +161,7 @@ class INaturalist(VisionDataset):
self.all_categories.extend([""] * (subcat_i - old_len + 1))
if self.categories_map[subcat_i]:
raise RuntimeError(f"Duplicate category {subcat}")
self.categories_map[subcat_i] = {'super': sindex}
self.categories_map[subcat_i] = {"super": sindex}
self.all_categories[subcat_i] = os.path.join(scat, subcat)
# validate the dictionary
@@ -183,9 +179,7 @@ class INaturalist(VisionDataset):
"""
cat_id, fname = self.index[index]
img = Image.open(os.path.join(self.root,
self.all_categories[cat_id],
fname))
img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))
target: Any = []
for t in self.target_type:
@@ -239,10 +233,8 @@ class INaturalist(VisionDataset):
base_root = os.path.dirname(self.root)
download_and_extract_archive(
DATASET_URLS[self.version],
base_root,
filename=f"{self.version}.tgz",
md5=DATASET_MD5[self.version])
DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
)
orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz"))
if not os.path.exists(orig_dir_name):
import time
import csv
import os
import time
import warnings
from os import path
import csv
from typing import Any, Callable, Dict, Optional, Tuple
from functools import partial
from multiprocessing import Pool
from os import path
from typing import Any, Callable, Dict, Optional, Tuple
from torch import Tensor
from .utils import download_and_extract_archive, download_url, verify_str_arg, check_integrity
from .folder import find_classes, make_dataset
from .utils import download_and_extract_archive, download_url, verify_str_arg, check_integrity
from .video_utils import VideoClips
from .vision import VisionDataset
@@ -214,18 +213,13 @@ class Kinetics(VisionDataset):
start=int(row["time_start"]),
end=int(row["time_end"]),
)
label = (
row["label"]
.replace(" ", "_")
.replace("'", "")
.replace("(", "")
.replace(")", "")
)
label = row["label"].replace(" ", "_").replace("'", "").replace("(", "").replace(")", "")
os.makedirs(path.join(self.split_folder, label), exist_ok=True)
downloaded_file = path.join(self.split_folder, f)
if path.isfile(downloaded_file):
os.replace(
downloaded_file, path.join(self.split_folder, label, f),
downloaded_file,
path.join(self.split_folder, label, f),
)
@property
@@ -303,11 +297,12 @@ class Kinetics400(Kinetics):
split: Any = None,
download: Any = None,
num_download_workers: Any = None,
**kwargs: Any
**kwargs: Any,
) -> None:
warnings.warn(
"Kinetics400 is deprecated and will be removed in a future release."
"It was replaced by Kinetics(..., num_classes=\"400\").")
'It was replaced by Kinetics(..., num_classes="400").'
)
if any(value is not None for value in (num_classes, split, download, num_download_workers)):
raise RuntimeError(
"Usage of 'num_classes', 'split', 'download', or 'num_download_workers' is not supported in "
@@ -73,9 +73,7 @@ class Kitti(VisionDataset):
if download:
self.download()
if not self._check_exists():
raise RuntimeError(
"Dataset not found. You may use download=True to download it."
)
raise RuntimeError("Dataset not found. You may use download=True to download it.")
image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name)
if self.train:
@@ -83,9 +81,7 @@ class Kitti(VisionDataset):
for img_file in os.listdir(image_dir):
self.images.append(os.path.join(image_dir, img_file))
if self.train:
self.targets.append(
os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt")
)
self.targets.append(os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt"))
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""Get item at a given index.
@@ -117,16 +113,18 @@ class Kitti(VisionDataset):
with open(self.targets[index]) as inp:
content = csv.reader(inp, delimiter=" ")
for line in content:
target.append({
"type": line[0],
"truncated": float(line[1]),
"occluded": int(line[2]),
"alpha": float(line[3]),
"bbox": [float(x) for x in line[4:8]],
"dimensions": [float(x) for x in line[8:11]],
"location": [float(x) for x in line[11:14]],
"rotation_y": float(line[14]),
})
target.append(
{
"type": line[0],
"truncated": float(line[1]),
"occluded": int(line[2]),
"alpha": float(line[3]),
"bbox": [float(x) for x in line[4:8]],
"dimensions": [float(x) for x in line[8:11]],
"location": [float(x) for x in line[11:14]],
"rotation_y": float(line[14]),
}
)
return target
def __len__(self) -> int:
@@ -141,10 +139,7 @@ class Kitti(VisionDataset):
folders = [self.image_dir_name]
if self.train:
folders.append(self.labels_dir_name)
return all(
os.path.isdir(os.path.join(self._raw_folder, self._location, fname))
for fname in folders
)
return all(os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) for fname in folders)
def download(self) -> None:
"""Download the KITTI data if it doesn't exist already."""
import os
from typing import Any, Callable, List, Optional, Tuple
from PIL import Image
from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
from .vision import VisionDataset
class _LFW(VisionDataset):
base_folder = 'lfw-py'
base_folder = "lfw-py"
download_url_prefix = "http://vis-www.cs.umass.edu/lfw/"
file_dict = {
'original': ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"),
'funneled': ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"),
'deepfunneled': ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201")
"original": ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"),
"funneled": ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"),
"deepfunneled": ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201"),
}
checksums = {
'pairs.txt': '9f1ba174e4e1c508ff7cdf10ac338a7d',
'pairsDevTest.txt': '5132f7440eb68cf58910c8a45a2ac10b',
'pairsDevTrain.txt': '4f27cbf15b2da4a85c1907eb4181ad21',
'people.txt': '450f0863dd89e85e73936a6d71a3474b',
'peopleDevTest.txt': 'e4bf5be0a43b5dcd9dc5ccfcb8fb19c5',
'peopleDevTrain.txt': '54eaac34beb6d042ed3a7d883e247a21',
'lfw-names.txt': 'a6d0a479bd074669f656265a6e693f6d'
"pairs.txt": "9f1ba174e4e1c508ff7cdf10ac338a7d",
"pairsDevTest.txt": "5132f7440eb68cf58910c8a45a2ac10b",
"pairsDevTrain.txt": "4f27cbf15b2da4a85c1907eb4181ad21",
"people.txt": "450f0863dd89e85e73936a6d71a3474b",
"peopleDevTest.txt": "e4bf5be0a43b5dcd9dc5ccfcb8fb19c5",
"peopleDevTrain.txt": "54eaac34beb6d042ed3a7d883e247a21",
"lfw-names.txt": "a6d0a479bd074669f656265a6e693f6d",
}
annot_file = {'10fold': '', 'train': 'DevTrain', 'test': 'DevTest'}
annot_file = {"10fold": "", "train": "DevTrain", "test": "DevTest"}
names = "lfw-names.txt"
def __init__(
@@ -37,14 +39,15 @@ class _LFW(VisionDataset):
target_transform: Optional[Callable] = None,
download: bool = False,
):
super(_LFW, self).__init__(os.path.join(root, self.base_folder),
transform=transform, target_transform=target_transform)
super(_LFW, self).__init__(
os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform
)
self.image_set = verify_str_arg(image_set.lower(), 'image_set', self.file_dict.keys())
self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys())
images_dir, self.filename, self.md5 = self.file_dict[self.image_set]
self.view = verify_str_arg(view.lower(), 'view', ['people', 'pairs'])
self.split = verify_str_arg(split.lower(), 'split', ['10fold', 'train', 'test'])
self.view = verify_str_arg(view.lower(), "view", ["people", "pairs"])
self.split = verify_str_arg(split.lower(), "split", ["10fold", "train", "test"])
self.labels_file = f"{self.view}{self.annot_file[self.split]}.txt"
self.data: List[Any] = []
@@ -52,15 +55,14 @@ class _LFW(VisionDataset):
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
self.images_dir = os.path.join(self.root, images_dir)
def _loader(self, path: str) -> Image.Image:
with open(path, 'rb') as f:
with open(path, "rb") as f:
img = Image.open(f)
return img.convert('RGB')
return img.convert("RGB")
def _check_integrity(self):
st1 = check_integrity(os.path.join(self.root, self.filename), self.md5)
@@ -73,7 +75,7 @@ class _LFW(VisionDataset):
def download(self):
if self._check_integrity():
print('Files already downloaded and verified')
print("Files already downloaded and verified")
return
url = f"{self.download_url_prefix}{self.filename}"
download_and_extract_archive(url, self.root, filename=self.filename, md5=self.md5)
@@ -120,21 +122,20 @@ class LFWPeople(_LFW):
target_transform: Optional[Callable] = None,
download: bool = False,
):
super(LFWPeople, self).__init__(root, split, image_set, "people",
transform, target_transform, download)
super(LFWPeople, self).__init__(root, split, image_set, "people", transform, target_transform, download)
self.class_to_idx = self._get_classes()
self.data, self.targets = self._get_people()
def _get_people(self):
data, targets = [], []
with open(os.path.join(self.root, self.labels_file), 'r') as f:
with open(os.path.join(self.root, self.labels_file), "r") as f:
lines = f.readlines()
n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0)
for fold in range(n_folds):
n_lines = int(lines[s])
people = [line.strip().split("\t") for line in lines[s + 1: s + n_lines + 1]]
people = [line.strip().split("\t") for line in lines[s + 1 : s + n_lines + 1]]
s += n_lines + 1
for i, (identity, num_imgs) in enumerate(people):
for num in range(1, int(num_imgs) + 1):
@@ -145,7 +146,7 @@ class LFWPeople(_LFW):
return data, targets
def _get_classes(self):
with open(os.path.join(self.root, self.names), 'r') as f:
with open(os.path.join(self.root, self.names), "r") as f:
lines = f.readlines()
names = [line.strip().split()[0] for line in lines]
class_to_idx = {name: i for i, name in enumerate(names)}
@@ -203,14 +204,13 @@ class LFWPairs(_LFW):
target_transform: Optional[Callable] = None,
download: bool = False,
):
super(LFWPairs, self).__init__(root, split, image_set, "pairs",
transform, target_transform, download)
super(LFWPairs, self).__init__(root, split, image_set, "pairs", transform, target_transform, download)
self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)
def _get_pairs(self, images_dir):
pair_names, data, targets = [], [], []
with open(os.path.join(self.root, self.labels_file), 'r') as f:
with open(os.path.join(self.root, self.labels_file), "r") as f:
lines = f.readlines()
if self.split == "10fold":
n_folds, n_pairs = lines[0].split("\t")
@@ -220,9 +220,9 @@ class LFWPairs(_LFW):
s = 1
for fold in range(n_folds):
matched_pairs = [line.strip().split("\t") for line in lines[s: s + n_pairs]]
unmatched_pairs = [line.strip().split("\t") for line in lines[s + n_pairs: s + (2 * n_pairs)]]
s += (2 * n_pairs)
matched_pairs = [line.strip().split("\t") for line in lines[s : s + n_pairs]]
unmatched_pairs = [line.strip().split("\t") for line in lines[s + n_pairs : s + (2 * n_pairs)]]
s += 2 * n_pairs
for pair in matched_pairs:
img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[0], pair[2]), 1
pair_names.append((pair[0], pair[0]))
from .vision import VisionDataset
from PIL import Image
import io
import os
import os.path
import io
import pickle
import string
from collections.abc import Iterable
import pickle
from typing import Any, Callable, cast, List, Optional, Tuple, Union
from PIL import Image
from .utils import verify_str_arg, iterable_to_str
from .vision import VisionDataset
class LSUNClass(VisionDataset):
def __init__(
self, root: str, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None
self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
) -> None:
import lmdb
super(LSUNClass, self).__init__(root, transform=transform,
target_transform=target_transform)
self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False,
readahead=False, meminit=False)
super(LSUNClass, self).__init__(root, transform=transform, target_transform=target_transform)
self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
with self.env.begin(write=False) as txn:
self.length = txn.stat()['entries']
cache_file = '_cache_' + ''.join(c for c in root if c in string.ascii_letters)
self.length = txn.stat()["entries"]
cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters)
if os.path.isfile(cache_file):
self.keys = pickle.load(open(cache_file, "rb"))
else:
@@ -40,7 +40,7 @@ class LSUNClass(VisionDataset):
buf = io.BytesIO()
buf.write(imgbuf)
buf.seek(0)
img = Image.open(buf).convert('RGB')
img = Image.open(buf).convert("RGB")
if self.transform is not None:
img = self.transform(img)
@@ -71,22 +71,19 @@ class LSUN(VisionDataset):
"""
def __init__(
self,
root: str,
classes: Union[str, List[str]] = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
self,
root: str,
classes: Union[str, List[str]] = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super(LSUN, self).__init__(root, transform=transform,
target_transform=target_transform)
super(LSUN, self).__init__(root, transform=transform, target_transform=target_transform)
self.classes = self._verify_classes(classes)
# for each class, create an LSUNClassDataset
self.dbs = []
for c in self.classes:
self.dbs.append(LSUNClass(
root=os.path.join(root, f"{c}_lmdb"),
transform=transform))
self.dbs.append(LSUNClass(root=os.path.join(root, f"{c}_lmdb"), transform=transform))
self.indices = []
count = 0
@@ -97,35 +94,41 @@ class LSUN(VisionDataset):
self.length = count
def _verify_classes(self, classes: Union[str, List[str]]) -> List[str]:
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
'conference_room', 'dining_room', 'kitchen',
'living_room', 'restaurant', 'tower']
dset_opts = ['train', 'val', 'test']
categories = [
"bedroom",
"bridge",
"church_outdoor",
"classroom",
"conference_room",
"dining_room",
"kitchen",
"living_room",
"restaurant",
"tower",
]
dset_opts = ["train", "val", "test"]
try:
classes = cast(str, classes)
verify_str_arg(classes, "classes", dset_opts)
if classes == 'test':
if classes == "test":
classes = [classes]
else:
classes = [c + '_' + classes for c in categories]
classes = [c + "_" + classes for c in categories]
except ValueError:
if not isinstance(classes, Iterable):
msg = ("Expected type str or Iterable for argument classes, "
"but got type {}.")
msg = "Expected type str or Iterable for argument classes, " "but got type {}."
raise ValueError(msg.format(type(classes)))
classes = list(classes)
msg_fmtstr_type = ("Expected type str for elements in argument classes, "
"but got type {}.")
msg_fmtstr_type = "Expected type str for elements in argument classes, " "but got type {}."
for c in classes:
verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
c_short = c.split('_')
category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]
c_short = c.split("_")
category, dset_opt = "_".join(c_short[:-1]), c_short[-1]
msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
msg = msg_fmtstr.format(category, "LSUN class",
iterable_to_str(categories))
msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories))
verify_str_arg(category, valid_values=categories, custom_msg=msg)
msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
from .vision import VisionDataset
import warnings
from PIL import Image
import codecs
import os
import os.path
import numpy as np
import torch
import codecs
import shutil
import string
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.error import URLError
import numpy as np
import torch
from PIL import Image
from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity
import shutil
from .vision import VisionDataset
class MNIST(VisionDataset):
@@ -31,21 +33,31 @@ class MNIST(VisionDataset):
"""
mirrors = [
'http://yann.lecun.com/exdb/mnist/',
'https://ossci-datasets.s3.amazonaws.com/mnist/',
"http://yann.lecun.com/exdb/mnist/",
"https://ossci-datasets.s3.amazonaws.com/mnist/",
]
resources = [
("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
]
training_file = 'training.pt'
test_file = 'test.pt'
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
training_file = "training.pt"
test_file = "test.pt"
classes = [
"0 - zero",
"1 - one",
"2 - two",
"3 - three",
"4 - four",
"5 - five",
"6 - six",
"7 - seven",
"8 - eight",
"9 - nine",
]
@property
def train_labels(self):
@@ -68,15 +80,14 @@ class MNIST(VisionDataset):
return self.data
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(MNIST, self).__init__(root, transform=transform,
target_transform=target_transform)
super(MNIST, self).__init__(root, transform=transform, target_transform=target_transform)
self.train = train # training set or test set
if self._check_legacy_exist():
@@ -87,8 +98,7 @@ class MNIST(VisionDataset):
self.download()
if not self._check_exists():
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')
raise RuntimeError("Dataset not found." + " You can use download=True to download it")
self.data, self.targets = self._load_data()
@@ -128,7 +138,7 @@ class MNIST(VisionDataset):
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
img = Image.fromarray(img.numpy(), mode="L")
if self.transform is not None:
img = self.transform(img)
@@ -143,11 +153,11 @@ class MNIST(VisionDataset):
@property
def raw_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, 'raw')
return os.path.join(self.root, self.__class__.__name__, "raw")
@property
def processed_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, 'processed')
return os.path.join(self.root, self.__class__.__name__, "processed")
@property
def class_to_idx(self) -> Dict[str, int]:
@@ -173,15 +183,9 @@ class MNIST(VisionDataset):
url = "{}{}".format(mirror, filename)
try:
print("Downloading {}".format(url))
download_and_extract_archive(
url, download_root=self.raw_folder,
filename=filename,
md5=md5
)
download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
except URLError as error:
print(
"Failed to download (trying next):\n{}".format(error)
)
print("Failed to download (trying next):\n{}".format(error))
continue
finally:
print()
@@ -209,18 +213,16 @@ class FashionMNIST(MNIST):
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
mirrors = [
"http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"
]
mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
resources = [
("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310")
("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
]
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
class KMNIST(MNIST):
@@ -239,17 +241,16 @@ class KMNIST(MNIST):
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
mirrors = [
"http://codh.rois.ac.jp/kmnist/dataset/kmnist/"
]
mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
resources = [
("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134")
("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
]
classes = ['o', 'ki', 'su', 'tsu', 'na', 'ha', 'ma', 'ya', 're', 'wo']
classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]
class EMNIST(MNIST):
@@ -271,19 +272,20 @@ class EMNIST(MNIST):
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
url = 'https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip'
url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip"
md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
# Merged Classes assumes Same structure for both uppercase and lowercase version
_merged_classes = {'c', 'i', 'j', 'k', 'l', 'm', 'o', 'p', 's', 'u', 'v', 'w', 'x', 'y', 'z'}
_merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
_all_classes = set(string.digits + string.ascii_letters)
classes_split_dict = {
'byclass': sorted(list(_all_classes)),
'bymerge': sorted(list(_all_classes - _merged_classes)),
'balanced': sorted(list(_all_classes - _merged_classes)),
'letters': ['N/A'] + list(string.ascii_lowercase),
'digits': list(string.digits),
'mnist': list(string.digits),
"byclass": sorted(list(_all_classes)),
"bymerge": sorted(list(_all_classes - _merged_classes)),
"balanced": sorted(list(_all_classes - _merged_classes)),
"letters": ["N/A"] + list(string.ascii_lowercase),
"digits": list(string.digits),
"mnist": list(string.digits),
}
def __init__(self, root: str, split: str, **kwargs: Any) -> None:
@@ -295,11 +297,11 @@ class EMNIST(MNIST):
@staticmethod
def _training_file(split) -> str:
return 'training_{}.pt'.format(split)
return "training_{}.pt".format(split)
@staticmethod
def _test_file(split) -> str:
return 'test_{}.pt'.format(split)
return "test_{}.pt".format(split)
@property
def _file_prefix(self) -> str:
@@ -328,9 +330,9 @@ class EMNIST(MNIST):
os.makedirs(self.raw_folder, exist_ok=True)
download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
gzip_folder = os.path.join(self.raw_folder, 'gzip')
gzip_folder = os.path.join(self.raw_folder, "gzip")
for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith('.gz'):
if gzip_file.endswith(".gz"):
extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
shutil.rmtree(gzip_folder)
@@ -365,39 +367,60 @@ class QMNIST(MNIST):
training set ot the testing set. Default: True.
"""
subsets = {
'train': 'train',
'test': 'test',
'test10k': 'test',
'test50k': 'test',
'nist': 'nist'
}
subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment]
'train': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz',
'ed72d4157d28c017586c42bc6afe6370'),
('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz',
'0058f8dd561b90ffdd0f734c6a30e5e4')],
'test': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz',
'1394631089c404de565df7b7aeaf9412'),
('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz',
'5b5b05890a5e13444e108efe57b788aa')],
'nist': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz',
'7f124b3b8ab81486c9d8c2749c17f834'),
('https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz',
'5ed0e788978e45d4a8bd4b7caec3d79d')]
"train": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
"ed72d4157d28c017586c42bc6afe6370",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
"0058f8dd561b90ffdd0f734c6a30e5e4",
),
],
"test": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
"1394631089c404de565df7b7aeaf9412",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
"5b5b05890a5e13444e108efe57b788aa",
),
],
"nist": [
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
"7f124b3b8ab81486c9d8c2749c17f834",
),
(
"https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
"5ed0e788978e45d4a8bd4b7caec3d79d",
),
],
}
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
classes = [
"0 - zero",
"1 - one",
"2 - two",
"3 - three",
"4 - four",
"5 - five",
"6 - six",
"7 - seven",
"8 - eight",
"9 - nine",
]
def __init__(
self, root: str, what: Optional[str] = None, compat: bool = True,
train: bool = True, **kwargs: Any
self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
) -> None:
if what is None:
what = 'train' if train else 'test'
what = "train" if train else "test"
self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
self.compat = compat
self.data_file = what + '.pt'
self.data_file = what + ".pt"
self.training_file = self.data_file
self.test_file = self.data_file
super(QMNIST, self).__init__(root, train, **kwargs)
@@ -417,16 +440,16 @@ class QMNIST(MNIST):
def _load_data(self):
data = read_sn3_pascalvincent_tensor(self.images_file)
assert (data.dtype == torch.uint8)
assert (data.ndimension() == 3)
assert data.dtype == torch.uint8
assert data.ndimension() == 3
targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
assert (targets.ndimension() == 2)
assert targets.ndimension() == 2
if self.what == 'test10k':
if self.what == "test10k":
data = data[0:10000, :, :].clone()
targets = targets[0:10000, :].clone()
elif self.what == 'test50k':
elif self.what == "test50k":
data = data[10000:, :, :].clone()
targets = targets[10000:, :].clone()
@@ -434,7 +457,7 @@ class QMNIST(MNIST):
def download(self) -> None:
"""Download the QMNIST data if it doesn't exist already.
Note that we only download what has been asked for (argument 'what').
Note that we only download what has been asked for (argument 'what').
"""
if self._check_exists():
return
@@ -443,7 +466,7 @@ class QMNIST(MNIST):
split = self.resources[self.subsets[self.what]]
for url, md5 in split:
filename = url.rpartition('/')[2]
filename = url.rpartition("/")[2]
file_path = os.path.join(self.raw_folder, filename)
if not os.path.isfile(file_path):
download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5)
@@ -451,7 +474,7 @@ class QMNIST(MNIST):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
# redefined to handle the compat flag
img, target = self.data[index], self.targets[index]
img = Image.fromarray(img.numpy(), mode='L')
img = Image.fromarray(img.numpy(), mode="L")
if self.transform is not None:
img = self.transform(img)
if self.compat:
@@ -465,22 +488,22 @@ class QMNIST(MNIST):
def get_int(b: bytes) -> int:
return int(codecs.encode(b, 'hex'), 16)
return int(codecs.encode(b, "hex"), 16)
SN3_PASCALVINCENT_TYPEMAP = {
8: (torch.uint8, np.uint8, np.uint8),
9: (torch.int8, np.int8, np.int8),
11: (torch.int16, np.dtype('>i2'), 'i2'),
12: (torch.int32, np.dtype('>i4'), 'i4'),
13: (torch.float32, np.dtype('>f4'), 'f4'),
14: (torch.float64, np.dtype('>f8'), 'f8')
11: (torch.int16, np.dtype(">i2"), "i2"),
12: (torch.int32, np.dtype(">i4"), "i4"),
13: (torch.float32, np.dtype(">f4"), "f4"),
14: (torch.float64, np.dtype(">f8"), "f8"),
}
def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object.
Argument may be a filename, compressed filename, or file object.
"""
# read
with open(path, "rb") as f:
@@ -492,7 +515,7 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
assert 1 <= nd <= 3
assert 8 <= ty <= 14
m = SN3_PASCALVINCENT_TYPEMAP[ty]
s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
assert parsed.shape[0] == np.prod(s) or not strict
return torch.from_numpy(parsed.astype(m[2])).view(*s)
@@ -500,13 +523,13 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
def read_label_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
assert(x.dtype == torch.uint8)
assert(x.ndimension() == 1)
assert x.dtype == torch.uint8
assert x.ndimension() == 1
return x.long()
def read_image_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
assert(x.dtype == torch.uint8)
assert(x.ndimension() == 3)
assert x.dtype == torch.uint8
assert x.ndimension() == 3
return x
from PIL import Image
from os.path import join
from typing import Any, Callable, List, Optional, Tuple
from .vision import VisionDataset
from PIL import Image
from .utils import download_and_extract_archive, check_integrity, list_dir, list_files
from .vision import VisionDataset
class Omniglot(VisionDataset):
@@ -21,38 +23,40 @@ class Omniglot(VisionDataset):
puts it in root directory. If the zip files are already downloaded, they are not
downloaded again.
"""
folder = 'omniglot-py'
download_url_prefix = 'https://raw.githubusercontent.com/brendenlake/omniglot/master/python'
folder = "omniglot-py"
download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python"
zips_md5 = {
'images_background': '68d2efa1b9178cc56df9314c21c6e718',
'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
"images_background": "68d2efa1b9178cc56df9314c21c6e718",
"images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811",
}
def __init__(
self,
root: str,
background: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
self,
root: str,
background: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(Omniglot, self).__init__(join(root, self.folder), transform=transform,
target_transform=target_transform)
super(Omniglot, self).__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
self.background = background
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
self.target_folder = join(self.root, self._get_target_folder())
self._alphabets = list_dir(self.target_folder)
self._characters: List[str] = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))]
for a in self._alphabets], [])
self._character_images = [[(image, idx) for image in list_files(join(self.target_folder, character), '.png')]
for idx, character in enumerate(self._characters)]
self._characters: List[str] = sum(
[[join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets], []
)
self._character_images = [
[(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
for idx, character in enumerate(self._characters)
]
self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, [])
def __len__(self) -> int:
@@ -68,7 +72,7 @@ class Omniglot(VisionDataset):
"""
image_name, character_class = self._flat_character_images[index]
image_path = join(self.target_folder, self._characters[character_class], image_name)
image = Image.open(image_path, mode='r').convert('L')
image = Image.open(image_path, mode="r").convert("L")
if self.transform:
image = self.transform(image)
@@ -80,19 +84,19 @@ class Omniglot(VisionDataset):
def _check_integrity(self) -> bool:
zip_filename = self._get_target_folder()
if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]):
if not check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename]):
return False
return True
def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
print("Files already downloaded and verified")
return
filename = self._get_target_folder()
zip_filename = filename + '.zip'
url = self.download_url_prefix + '/' + zip_filename
zip_filename = filename + ".zip"
url = self.download_url_prefix + "/" + zip_filename
download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])
def _get_target_folder(self) -> str:
return 'images_background' if self.background else 'images_evaluation'
return "images_background" if self.background else "images_evaluation"
import os
import numpy as np
from PIL import Image
from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch
from .vision import VisionDataset
from PIL import Image
from .utils import download_url
from .vision import VisionDataset
class PhotoTour(VisionDataset):
@@ -33,56 +33,67 @@ class PhotoTour(VisionDataset):
downloaded again.
"""
urls = {
'notredame_harris': [
'http://matthewalunbrown.com/patchdata/notredame_harris.zip',
'notredame_harris.zip',
'69f8c90f78e171349abdf0307afefe4d'
],
'yosemite_harris': [
'http://matthewalunbrown.com/patchdata/yosemite_harris.zip',
'yosemite_harris.zip',
'a73253d1c6fbd3ba2613c45065c00d46'
"notredame_harris": [
"http://matthewalunbrown.com/patchdata/notredame_harris.zip",
"notredame_harris.zip",
"69f8c90f78e171349abdf0307afefe4d",
],
'liberty_harris': [
'http://matthewalunbrown.com/patchdata/liberty_harris.zip',
'liberty_harris.zip',
'c731fcfb3abb4091110d0ae8c7ba182c'
"yosemite_harris": [
"http://matthewalunbrown.com/patchdata/yosemite_harris.zip",
"yosemite_harris.zip",
"a73253d1c6fbd3ba2613c45065c00d46",
],
'notredame': [
'http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip',
'notredame.zip',
'509eda8535847b8c0a90bbb210c83484'
"liberty_harris": [
"http://matthewalunbrown.com/patchdata/liberty_harris.zip",
"liberty_harris.zip",
"c731fcfb3abb4091110d0ae8c7ba182c",
],
'yosemite': [
'http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip',
'yosemite.zip',
'533b2e8eb7ede31be40abc317b2fd4f0'
],
'liberty': [
'http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip',
'liberty.zip',
'fdd9152f138ea5ef2091746689176414'
"notredame": [
"http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip",
"notredame.zip",
"509eda8535847b8c0a90bbb210c83484",
],
"yosemite": ["http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip", "yosemite.zip", "533b2e8eb7ede31be40abc317b2fd4f0"],
"liberty": ["http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip", "liberty.zip", "fdd9152f138ea5ef2091746689176414"],
}
means = {
"notredame": 0.4854,
"yosemite": 0.4844,
"liberty": 0.4437,
"notredame_harris": 0.4854,
"yosemite_harris": 0.4844,
"liberty_harris": 0.4437,
}
stds = {
"notredame": 0.1864,
"yosemite": 0.1818,
"liberty": 0.2019,
"notredame_harris": 0.1864,
"yosemite_harris": 0.1818,
"liberty_harris": 0.2019,
}
lens = {
"notredame": 468159,
"yosemite": 633587,
"liberty": 450092,
"liberty_harris": 379587,
"yosemite_harris": 450912,
"notredame_harris": 325295,
}
means = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437,
'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437}
stds = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019,
'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019}
lens = {'notredame': 468159, 'yosemite': 633587, 'liberty': 450092,
'liberty_harris': 379587, 'yosemite_harris': 450912, 'notredame_harris': 325295}
image_ext = 'bmp'
info_file = 'info.txt'
matches_files = 'm50_100000_100000_0.txt'
image_ext = "bmp"
info_file = "info.txt"
matches_files = "m50_100000_100000_0.txt"
def __init__(
self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False
self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False
) -> None:
super(PhotoTour, self).__init__(root, transform=transform)
self.name = name
self.data_dir = os.path.join(self.root, name)
self.data_down = os.path.join(self.root, '{}.zip'.format(name))
self.data_file = os.path.join(self.root, '{}.pt'.format(name))
self.data_down = os.path.join(self.root, "{}.zip".format(name))
self.data_file = os.path.join(self.root, "{}.pt".format(name))
self.train = train
self.mean = self.means[name]
@@ -128,7 +139,7 @@ class PhotoTour(VisionDataset):
def download(self) -> None:
if self._check_datafile_exists():
print('# Found cached data {}'.format(self.data_file))
print("# Found cached data {}".format(self.data_file))
return
if not self._check_downloaded():
@@ -140,25 +151,26 @@ class PhotoTour(VisionDataset):
download_url(url, self.root, filename, md5)
print('# Extracting data {}\n'.format(self.data_down))
print("# Extracting data {}\n".format(self.data_down))
import zipfile
with zipfile.ZipFile(fpath, 'r') as z:
with zipfile.ZipFile(fpath, "r") as z:
z.extractall(self.data_dir)
os.unlink(fpath)
def cache(self) -> None:
# process and save as torch files
print('# Caching data {}'.format(self.data_file))
print("# Caching data {}".format(self.data_file))
dataset = (
read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
read_info_file(self.data_dir, self.info_file),
read_matches_files(self.data_dir, self.matches_files)
read_matches_files(self.data_dir, self.matches_files),
)
with open(self.data_file, 'wb') as f:
with open(self.data_file, "wb") as f:
torch.save(dataset, f)
def extra_repr(self) -> str:
@@ -166,17 +178,14 @@ class PhotoTour(VisionDataset):
def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
"""Return a Tensor containing the patches
"""
"""Return a Tensor containing the patches"""
def PIL2array(_img: Image.Image) -> np.ndarray:
"""Convert PIL image type to numpy 2D array
"""
"""Convert PIL image type to numpy 2D array"""
return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)
def find_files(_data_dir: str, _image_ext: str) -> List[str]:
"""Return a list with the file names of the images containing the patches
"""
"""Return a list with the file names of the images containing the patches"""
files = []
# find those files with the specified extension
for file_dir in os.listdir(_data_dir):
@@ -198,22 +207,21 @@ def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
"""Return a Tensor containing the list of labels
Read the file and keep only the ID of the 3D point.
Read the file and keep only the ID of the 3D point.
"""
with open(os.path.join(data_dir, info_file), 'r') as f:
with open(os.path.join(data_dir, info_file), "r") as f:
labels = [int(line.split()[0]) for line in f]
return torch.LongTensor(labels)
def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
"""Return a Tensor containing the ground truth matches
Read the file and keep only 3D point ID.
Matches are represented with a 1, non matches with a 0.
Read the file and keep only 3D point ID.
Matches are represented with a 1, non matches with a 0.
"""
matches = []
with open(os.path.join(data_dir, matches_file), 'r') as f:
with open(os.path.join(data_dir, matches_file), "r") as f:
for line in f:
line_split = line.split()
matches.append([int(line_split[0]), int(line_split[3]),
int(line_split[1] == line_split[4])])
matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
return torch.LongTensor(matches)
from .clip_sampler import DistributedSampler, UniformClipSampler, RandomClipSampler
__all__ = ('DistributedSampler', 'UniformClipSampler', 'RandomClipSampler')
__all__ = ("DistributedSampler", "UniformClipSampler", "RandomClipSampler")
import math
from typing import Optional, List, Iterator, Sized, Union, cast
import torch
from torch.utils.data import Sampler
import torch.distributed as dist
from torch.utils.data import Sampler
from torchvision.datasets.video_utils import VideoClips
from typing import Optional, List, Iterator, Sized, Union, cast
class DistributedSampler(Sampler):
@@ -36,12 +37,12 @@ class DistributedSampler(Sampler):
"""
def __init__(
self,
dataset: Sized,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = False,
group_size: int = 1,
self,
dataset: Sized,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = False,
group_size: int = 1,
) -> None:
if num_replicas is None:
if not dist.is_available():
@@ -51,9 +52,11 @@ class DistributedSampler(Sampler):
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
assert len(dataset) % group_size == 0, (
"dataset length must be a multiplier of group size"
"dataset length: %d, group size: %d" % (len(dataset), group_size)
assert (
len(dataset) % group_size == 0
), "dataset length must be a multiplier of group size" "dataset length: %d, group size: %d" % (
len(dataset),
group_size,
)
self.dataset = dataset
self.group_size = group_size
......@@ -61,9 +64,7 @@ class DistributedSampler(Sampler):
self.rank = rank
self.epoch = 0
dataset_group_length = len(dataset) // group_size
self.num_group_samples = int(
math.ceil(dataset_group_length * 1.0 / self.num_replicas)
)
self.num_group_samples = int(math.ceil(dataset_group_length * 1.0 / self.num_replicas))
self.num_samples = self.num_group_samples * group_size
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
......@@ -79,16 +80,14 @@ class DistributedSampler(Sampler):
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
indices += indices[: (self.total_size - len(indices))]
assert len(indices) == self.total_size
total_group_size = self.total_size // self.group_size
indices = torch.reshape(
torch.LongTensor(indices), (total_group_size, self.group_size)
)
indices = torch.reshape(torch.LongTensor(indices), (total_group_size, self.group_size))
# subsample
indices = indices[self.rank:total_group_size:self.num_replicas, :]
indices = indices[self.rank : total_group_size : self.num_replicas, :]
indices = torch.reshape(indices, (-1,)).tolist()
assert len(indices) == self.num_samples
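A standalone sketch of the group-aware subsampling above, with made-up sizes (8 samples, group_size=2, two replicas): each rank receives whole groups rather than interleaved single indices.

import torch

indices = list(range(8))                     # made-up dataset of length 8
group_size, num_replicas = 2, 2
total_group_size = len(indices) // group_size
grouped = torch.reshape(torch.LongTensor(indices), (total_group_size, group_size))
for rank in range(num_replicas):
    picked = torch.reshape(grouped[rank:total_group_size:num_replicas, :], (-1,)).tolist()
    print(rank, picked)                      # rank 0 -> [0, 1, 4, 5], rank 1 -> [2, 3, 6, 7]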
......@@ -115,10 +114,10 @@ class UniformClipSampler(Sampler):
video_clips (VideoClips): video clips to sample from
num_clips_per_video (int): number of clips to be sampled per video
"""
def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None:
if not isinstance(video_clips, VideoClips):
raise TypeError("Expected video_clips to be an instance of VideoClips, "
"got {}".format(type(video_clips)))
raise TypeError("Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips)))
self.video_clips = video_clips
self.num_clips_per_video = num_clips_per_video
......@@ -132,19 +131,13 @@ class UniformClipSampler(Sampler):
# corner case where video decoding fails
continue
sampled = (
torch.linspace(s, s + length - 1, steps=self.num_clips_per_video)
.floor()
.to(torch.int64)
)
sampled = torch.linspace(s, s + length - 1, steps=self.num_clips_per_video).floor().to(torch.int64)
s += length
idxs.append(sampled)
return iter(cast(List[int], torch.cat(idxs).tolist()))
def __len__(self) -> int:
return sum(
self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0
)
return sum(self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0)
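Worked example of the linspace-based sampling above, with made-up numbers: one video contributing 10 clips starting at offset 0, sampled with num_clips_per_video=3.

import torch

s, length, num_clips_per_video = 0, 10, 3    # made-up: 10 clips, starting offset 0
sampled = torch.linspace(s, s + length - 1, steps=num_clips_per_video).floor().to(torch.int64)
print(sampled.tolist())                      # [0, 4, 9]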
class RandomClipSampler(Sampler):
......@@ -155,10 +148,10 @@ class RandomClipSampler(Sampler):
video_clips (VideoClips): video clips to sample from
max_clips_per_video (int): maximum number of clips to be sampled per video
"""
def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None:
if not isinstance(video_clips, VideoClips):
raise TypeError("Expected video_clips to be an instance of VideoClips, "
"got {}".format(type(video_clips)))
raise TypeError("Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips)))
self.video_clips = video_clips
self.max_clips_per_video = max_clips_per_video
......
import os
import shutil
from .vision import VisionDataset
from typing import Any, Callable, Optional, Tuple
import numpy as np
from PIL import Image
from .utils import download_url, verify_str_arg, download_and_extract_archive
from .vision import VisionDataset
class SBDataset(VisionDataset):
......@@ -50,30 +50,29 @@ class SBDataset(VisionDataset):
voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722"
def __init__(
self,
root: str,
image_set: str = "train",
mode: str = "boundaries",
download: bool = False,
transforms: Optional[Callable] = None,
self,
root: str,
image_set: str = "train",
mode: str = "boundaries",
download: bool = False,
transforms: Optional[Callable] = None,
) -> None:
try:
from scipy.io import loadmat
self._loadmat = loadmat
except ImportError:
raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: "
"pip install scipy")
raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: " "pip install scipy")
super(SBDataset, self).__init__(root, transforms)
self.image_set = verify_str_arg(image_set, "image_set",
("train", "val", "train_noval"))
self.image_set = verify_str_arg(image_set, "image_set", ("train", "val", "train_noval"))
self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
self.num_classes = 20
sbd_root = self.root
image_dir = os.path.join(sbd_root, 'img')
mask_dir = os.path.join(sbd_root, 'cls')
image_dir = os.path.join(sbd_root, "img")
mask_dir = os.path.join(sbd_root, "cls")
if download:
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
......@@ -81,36 +80,35 @@ class SBDataset(VisionDataset):
for f in ["cls", "img", "inst", "train.txt", "val.txt"]:
old_path = os.path.join(extracted_ds_root, f)
shutil.move(old_path, sbd_root)
download_url(self.voc_train_url, sbd_root, self.voc_split_filename,
self.voc_split_md5)
download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5)
if not os.path.isdir(sbd_root):
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt')
split_f = os.path.join(sbd_root, image_set.rstrip("\n") + ".txt")
with open(os.path.join(split_f), "r") as fh:
file_names = [x.strip() for x in fh.readlines()]
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
assert (len(self.images) == len(self.masks))
assert len(self.images) == len(self.masks)
self._get_target = self._get_segmentation_target \
if self.mode == "segmentation" else self._get_boundaries_target
self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target
def _get_segmentation_target(self, filepath: str) -> Image.Image:
mat = self._loadmat(filepath)
return Image.fromarray(mat['GTcls'][0]['Segmentation'][0])
return Image.fromarray(mat["GTcls"][0]["Segmentation"][0])
def _get_boundaries_target(self, filepath: str) -> np.ndarray:
mat = self._loadmat(filepath)
return np.concatenate([np.expand_dims(mat['GTcls'][0]['Boundaries'][0][i][0].toarray(), axis=0)
for i in range(self.num_classes)], axis=0)
return np.concatenate(
[np.expand_dims(mat["GTcls"][0]["Boundaries"][0][i][0].toarray(), axis=0) for i in range(self.num_classes)],
axis=0,
)
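A standalone shape check for the concatenation above, with tiny dense made-up arrays standing in for the toarray() output of the sparse .mat entries: the result is one (num_classes, H, W) boundary stack.

import numpy as np

num_classes, h, w = 20, 2, 2                 # made-up spatial size
maps = [np.zeros((h, w), dtype=np.uint8) for _ in range(num_classes)]
target = np.concatenate([np.expand_dims(m, axis=0) for m in maps], axis=0)
print(target.shape)                          # (20, 2, 2): one boundary map per class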
def __getitem__(self, index: int) -> Tuple[Any, Any]:
img = Image.open(self.images[index]).convert('RGB')
img = Image.open(self.images[index]).convert("RGB")
target = self._get_target(self.masks[index])
if self.transforms is not None:
......@@ -123,4 +121,4 @@ class SBDataset(VisionDataset):
def extra_repr(self) -> str:
lines = ["Image set: {image_set}", "Mode: {mode}"]
return '\n'.join(lines).format(**self.__dict__)
return "\n".join(lines).format(**self.__dict__)
from PIL import Image
from .utils import download_url, check_integrity
import os
from typing import Any, Callable, Optional, Tuple
import os
from PIL import Image
from .utils import download_url, check_integrity
from .vision import VisionDataset
......@@ -20,38 +21,37 @@ class SBU(VisionDataset):
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
url = "http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
filename = "SBUCaptionedPhotoDataset.tar.gz"
md5_checksum = '9aec147b3488753cf758b4d493422285'
md5_checksum = "9aec147b3488753cf758b4d493422285"
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
) -> None:
super(SBU, self).__init__(root, transform=transform,
target_transform=target_transform)
super(SBU, self).__init__(root, transform=transform, target_transform=target_transform)
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
# Read the caption for each photo
self.photos = []
self.captions = []
file1 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt')
file2 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_captions.txt')
file1 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")
file2 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_captions.txt")
for line1, line2 in zip(open(file1), open(file2)):
url = line1.rstrip()
photo = os.path.basename(url)
filename = os.path.join(self.root, 'dataset', photo)
filename = os.path.join(self.root, "dataset", photo)
if os.path.exists(filename):
caption = line2.rstrip()
self.photos.append(photo)
......@@ -65,8 +65,8 @@ class SBU(VisionDataset):
Returns:
tuple: (image, target) where target is a caption for the photo.
"""
filename = os.path.join(self.root, 'dataset', self.photos[index])
img = Image.open(filename).convert('RGB')
filename = os.path.join(self.root, "dataset", self.photos[index])
img = Image.open(filename).convert("RGB")
if self.transform is not None:
img = self.transform(img)
......@@ -93,21 +93,21 @@ class SBU(VisionDataset):
import tarfile
if self._check_integrity():
print('Files already downloaded and verified')
print("Files already downloaded and verified")
return
download_url(self.url, self.root, self.filename, self.md5_checksum)
# Extract file
with tarfile.open(os.path.join(self.root, self.filename), 'r:gz') as tar:
with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
tar.extractall(path=self.root)
# Download individual photos
with open(os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt')) as fh:
with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh:
for line in fh:
url = line.rstrip()
try:
download_url(url, os.path.join(self.root, 'dataset'))
download_url(url, os.path.join(self.root, "dataset"))
except OSError:
# The images point to public images on Flickr.
# Note: Images might be removed by users at anytime.
......
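For clarity, a small standalone sketch of the URL/caption pairing that SBU's __init__ performs above: the two text files are zipped line by line and a photo is kept only if it was actually downloaded. The strings and the "downloaded" set are made up.

import os

urls = ["http://example.com/a.jpg\n", "http://example.com/b.jpg\n"]   # made-up lines
captions = ["a dog on a beach\n", "a red bicycle\n"]                  # made-up lines
downloaded = {"a.jpg"}                       # pretend only the first photo is on disk
photos, kept = [], []
for line1, line2 in zip(urls, captions):
    photo = os.path.basename(line1.rstrip())
    if photo in downloaded:
        photos.append(photo)
        kept.append(line2.rstrip())
print(photos, kept)                          # ['a.jpg'] ['a dog on a beach']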
from PIL import Image
import os
import os.path
import numpy as np
from typing import Any, Callable, Optional, Tuple
from .vision import VisionDataset
import numpy as np
from PIL import Image
from .utils import download_url, check_integrity
from .vision import VisionDataset
class SEMEION(VisionDataset):
......@@ -24,30 +26,28 @@ class SEMEION(VisionDataset):
"""
url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data"
filename = "semeion.data"
md5_checksum = 'cb545d371d2ce14ec121470795a77432'
md5_checksum = "cb545d371d2ce14ec121470795a77432"
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
) -> None:
super(SEMEION, self).__init__(root, transform=transform,
target_transform=target_transform)
super(SEMEION, self).__init__(root, transform=transform, target_transform=target_transform)
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
fp = os.path.join(self.root, self.filename)
data = np.loadtxt(fp)
# convert value to 8 bit unsigned integer
# color (white #255) the pixels
self.data = (data[:, :256] * 255).astype('uint8')
self.data = (data[:, :256] * 255).astype("uint8")
self.data = np.reshape(self.data, (-1, 16, 16))
self.labels = np.nonzero(data[:, 256:])[1]
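A standalone sketch of the SEMEION decoding above, run on one made-up row: the first 256 entries are binary pixels (scaled to 0/255 and reshaped to 16x16) and the last 10 entries are a one-hot digit label recovered with np.nonzero.

import numpy as np

row = np.zeros((1, 266))                     # one made-up SEMEION row
row[0, :256] = np.random.randint(0, 2, 256)  # fake 16x16 binary digit
row[0, 256 + 3] = 1                          # one-hot label: digit 3
pixels = (row[:, :256] * 255).astype("uint8").reshape(-1, 16, 16)
labels = np.nonzero(row[:, 256:])[1]
print(pixels.shape, labels)                  # (1, 16, 16) [3]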
......@@ -63,7 +63,7 @@ class SEMEION(VisionDataset):
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img, mode='L')
img = Image.fromarray(img, mode="L")
if self.transform is not None:
img = self.transform(img)
......@@ -85,7 +85,7 @@ class SEMEION(VisionDataset):
def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
print("Files already downloaded and verified")
return
root = self.root
......
from PIL import Image
import os
import os.path
import numpy as np
from typing import Any, Callable, Optional, Tuple
from .vision import VisionDataset
import numpy as np
from PIL import Image
from .utils import check_integrity, download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
class STL10(VisionDataset):
......@@ -27,70 +28,60 @@ class STL10(VisionDataset):
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
base_folder = 'stl10_binary'
base_folder = "stl10_binary"
url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
filename = "stl10_binary.tar.gz"
tgz_md5 = '91f7769df0f17e558f3565bffb0c7dfb'
class_names_file = 'class_names.txt'
folds_list_file = 'fold_indices.txt'
tgz_md5 = "91f7769df0f17e558f3565bffb0c7dfb"
class_names_file = "class_names.txt"
folds_list_file = "fold_indices.txt"
train_list = [
['train_X.bin', '918c2871b30a85fa023e0c44e0bee87f'],
['train_y.bin', '5a34089d4802c674881badbb80307741'],
['unlabeled_X.bin', '5242ba1fed5e4be9e1e742405eb56ca4']
["train_X.bin", "918c2871b30a85fa023e0c44e0bee87f"],
["train_y.bin", "5a34089d4802c674881badbb80307741"],
["unlabeled_X.bin", "5242ba1fed5e4be9e1e742405eb56ca4"],
]
test_list = [
['test_X.bin', '7f263ba9f9e0b06b93213547f721ac82'],
['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e']
]
splits = ('train', 'train+unlabeled', 'unlabeled', 'test')
test_list = [["test_X.bin", "7f263ba9f9e0b06b93213547f721ac82"], ["test_y.bin", "36f9794fa4beb8a2c72628de14fa638e"]]
splits = ("train", "train+unlabeled", "unlabeled", "test")
def __init__(
self,
root: str,
split: str = "train",
folds: Optional[int] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
self,
root: str,
split: str = "train",
folds: Optional[int] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(STL10, self).__init__(root, transform=transform,
target_transform=target_transform)
super(STL10, self).__init__(root, transform=transform, target_transform=target_transform)
self.split = verify_str_arg(split, "split", self.splits)
self.folds = self._verify_folds(folds)
if download:
self.download()
elif not self._check_integrity():
raise RuntimeError(
'Dataset not found or corrupted. '
'You can use download=True to download it')
raise RuntimeError("Dataset not found or corrupted. " "You can use download=True to download it")
# now load the picked numpy arrays
self.labels: Optional[np.ndarray]
if self.split == 'train':
self.data, self.labels = self.__loadfile(
self.train_list[0][0], self.train_list[1][0])
if self.split == "train":
self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
self.__load_folds(folds)
elif self.split == 'train+unlabeled':
self.data, self.labels = self.__loadfile(
self.train_list[0][0], self.train_list[1][0])
elif self.split == "train+unlabeled":
self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0])
self.__load_folds(folds)
unlabeled_data, _ = self.__loadfile(self.train_list[2][0])
self.data = np.concatenate((self.data, unlabeled_data))
self.labels = np.concatenate(
(self.labels, np.asarray([-1] * unlabeled_data.shape[0])))
self.labels = np.concatenate((self.labels, np.asarray([-1] * unlabeled_data.shape[0])))
elif self.split == 'unlabeled':
elif self.split == "unlabeled":
self.data, _ = self.__loadfile(self.train_list[2][0])
self.labels = np.asarray([-1] * self.data.shape[0])
else: # self.split == 'test':
self.data, self.labels = self.__loadfile(
self.test_list[0][0], self.test_list[1][0])
self.data, self.labels = self.__loadfile(self.test_list[0][0], self.test_list[1][0])
class_file = os.path.join(
self.root, self.base_folder, self.class_names_file)
class_file = os.path.join(self.root, self.base_folder, self.class_names_file)
if os.path.isfile(class_file):
with open(class_file) as f:
self.classes = f.read().splitlines()
......@@ -101,8 +92,7 @@ class STL10(VisionDataset):
elif isinstance(folds, int):
if folds in range(10):
return folds
msg = ("Value for argument folds should be in the range [0, 10), "
"but got {}.")
msg = "Value for argument folds should be in the range [0, 10), " "but got {}."
raise ValueError(msg.format(folds))
else:
msg = "Expected type None or int for argument folds, but got type {}."
......@@ -140,13 +130,12 @@ class STL10(VisionDataset):
def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
labels = None
if labels_file:
path_to_labels = os.path.join(
self.root, self.base_folder, labels_file)
with open(path_to_labels, 'rb') as f:
path_to_labels = os.path.join(self.root, self.base_folder, labels_file)
with open(path_to_labels, "rb") as f:
labels = np.fromfile(f, dtype=np.uint8) - 1 # 0-based
path_to_data = os.path.join(self.root, self.base_folder, data_file)
with open(path_to_data, 'rb') as f:
with open(path_to_data, "rb") as f:
# read whole file in uint8 chunks
everything = np.fromfile(f, dtype=np.uint8)
images = np.reshape(everything, (-1, 3, 96, 96))
......@@ -156,7 +145,7 @@ class STL10(VisionDataset):
def _check_integrity(self) -> bool:
root = self.root
for fentry in (self.train_list + self.test_list):
for fentry in self.train_list + self.test_list:
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename)
if not check_integrity(fpath, md5):
......@@ -165,7 +154,7 @@ class STL10(VisionDataset):
def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
print("Files already downloaded and verified")
return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
self._check_integrity()
......@@ -177,11 +166,10 @@ class STL10(VisionDataset):
# loads one of the folds if specified
if folds is None:
return
path_to_folds = os.path.join(
self.root, self.base_folder, self.folds_list_file)
with open(path_to_folds, 'r') as f:
path_to_folds = os.path.join(self.root, self.base_folder, self.folds_list_file)
with open(path_to_folds, "r") as f:
str_idx = f.read().splitlines()[folds]
list_idx = np.fromstring(str_idx, dtype=np.int64, sep=' ')
list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ")
self.data = self.data[list_idx, :, :, :]
if self.labels is not None:
self.labels = self.labels[list_idx]
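Standalone illustration of how __load_folds above turns one line of fold_indices.txt into the index array used to subset the data; the fold line is made up.

import numpy as np

str_idx = "0 5 12 42"                        # a made-up fold line
list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ")
print(list_idx)                              # [ 0  5 12 42]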
from .vision import VisionDataset
from PIL import Image
import os
import os.path
import numpy as np
from typing import Any, Callable, Optional, Tuple
import numpy as np
from PIL import Image
from .utils import download_url, check_integrity, verify_str_arg
from .vision import VisionDataset
class SVHN(VisionDataset):
......@@ -33,23 +35,32 @@ class SVHN(VisionDataset):
"""
split_list = {
'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
"train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
"test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
"extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
"train": [
"http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
"train_32x32.mat",
"e26dedcc434d2e4c54c9b2d4a06d8373",
],
"test": [
"http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
"test_32x32.mat",
"eb5a983be6a315427106f1b164d9cef3",
],
"extra": [
"http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
"extra_32x32.mat",
"a93ce644f1a588dc4d68dda5feec44a7",
],
}
def __init__(
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(SVHN, self).__init__(root, transform=transform,
target_transform=target_transform)
super(SVHN, self).__init__(root, transform=transform, target_transform=target_transform)
self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
self.url = self.split_list[split][0]
self.filename = self.split_list[split][1]
......@@ -59,8 +70,7 @@ class SVHN(VisionDataset):
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
# import here rather than at top of file because this is
# an optional dependency for torchvision
......@@ -69,12 +79,12 @@ class SVHN(VisionDataset):
# reading(loading) mat file as array
loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
self.data = loaded_mat['X']
self.data = loaded_mat["X"]
# loading from the .mat file gives an np array of type np.uint8
# converting to np.int64, so that we have a LongTensor after
# the conversion from the numpy array
# the squeeze is needed to obtain a 1D tensor
self.labels = loaded_mat['y'].astype(np.int64).squeeze()
self.labels = loaded_mat["y"].astype(np.int64).squeeze()
# the svhn dataset assigns the class label "10" to the digit 0
# this makes it inconsistent with several loss functions
......
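The hunk is truncated before the remapping that the comment above describes. A hedged sketch of one way to perform it, on a made-up label array; this is not necessarily the exact code in the file.

import numpy as np

labels = np.array([10, 3, 10, 7], dtype=np.int64)   # made-up raw SVHN labels
np.place(labels, labels == 10, 0)                   # map class "10" back to digit 0
print(labels)                                       # [0 3 0 7]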
import os
from typing import Any, Dict, List, Tuple, Optional, Callable
from torch import Tensor
from .folder import find_classes, make_dataset
......@@ -62,13 +63,13 @@ class UCF101(VisionDataset):
_video_width: int = 0,
_video_height: int = 0,
_video_min_dimension: int = 0,
_audio_samples: int = 0
_audio_samples: int = 0,
) -> None:
super(UCF101, self).__init__(root)
if not 1 <= fold <= 3:
raise ValueError("fold should be between 1 and 3, got {}".format(fold))
extensions = ('avi',)
extensions = ("avi",)
self.fold = fold
self.train = train
......
from PIL import Image
import os
import numpy as np
from typing import Any, Callable, cast, Optional, Tuple
import numpy as np
from PIL import Image
from .utils import download_url
from .vision import VisionDataset
......@@ -26,28 +27,30 @@ class USPS(VisionDataset):
downloaded again.
"""
split_list = {
'train': [
"train": [
"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
"usps.bz2", 'ec16c51db3855ca6c91edd34d0e9b197'
"usps.bz2",
"ec16c51db3855ca6c91edd34d0e9b197",
],
'test': [
"test": [
"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
"usps.t.bz2", '8ea070ee2aca1ac39742fdd1ef5ed118'
"usps.t.bz2",
"8ea070ee2aca1ac39742fdd1ef5ed118",
],
}
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(USPS, self).__init__(root, transform=transform,
target_transform=target_transform)
split = 'train' if train else 'test'
super(USPS, self).__init__(root, transform=transform, target_transform=target_transform)
split = "train" if train else "test"
url, filename, checksum = self.split_list[split]
full_path = os.path.join(self.root, filename)
......@@ -55,9 +58,10 @@ class USPS(VisionDataset):
download_url(url, self.root, filename, md5=checksum)
import bz2
with bz2.open(full_path) as fp:
raw_data = [line.decode().split() for line in fp.readlines()]
tmp_list = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data]
imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16))
imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8)
targets = [int(d[0]) - 1 for d in raw_data]
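A standalone walk-through of the USPS parsing above on one synthetic libsvm-style line: values in [-1, 1] are rescaled to uint8 and the label is shifted from 1-10 to 0-9. The line is generated here, not read from the real archive.

import numpy as np

values = ["{}:{:.4f}".format(i + 1, v) for i, v in enumerate(np.linspace(-1, 1, 256))]
line = "7 " + " ".join(values)               # synthetic sample labelled as digit 6 (7 - 1)
data = line.split()
pixels = np.asarray([x.split(":")[-1] for x in data[1:]], dtype=np.float32).reshape(16, 16)
img = ((pixels + 1) / 2 * 255).astype(np.uint8)
target = int(data[0]) - 1
print(img.shape, int(img.min()), int(img.max()), target)   # (16, 16) 0 255 6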
......@@ -77,7 +81,7 @@ class USPS(VisionDataset):
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img, mode='L')
img = Image.fromarray(img, mode="L")
if self.transform is not None:
img = self.transform(img)
......