Unverified Commit d367a01a authored by Jirka Borovec's avatar Jirka Borovec Committed by GitHub
Browse files

Use f-strings almost everywhere, and other cleanups by applying pyupgrade (#4585)


Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 50dfe207
......@@ -371,7 +371,7 @@ def test_x_crop_save(method, tmpdir):
]
)
scripted_fn = torch.jit.script(fn)
scripted_fn.save(os.path.join(tmpdir, "t_op_list_{}.pt".format(method)))
scripted_fn.save(os.path.join(tmpdir, f"t_op_list_{method}.pt"))
class TestResize:
......@@ -816,7 +816,7 @@ def test_compose(device):
transformed_tensor = transforms(tensor)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))
assert_equal(transformed_tensor, transformed_tensor_script, msg=f"{transforms}")
t = T.Compose(
[
......@@ -854,7 +854,7 @@ def test_random_apply(device):
transformed_tensor = transforms(tensor)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))
assert_equal(transformed_tensor, transformed_tensor_script, msg=f"{transforms}")
if device == "cpu":
# Can't check this twice, otherwise
......
......@@ -163,7 +163,7 @@ class TestVideoTransforms:
@pytest.mark.parametrize("p", (0, 1))
def test_random_horizontal_flip_video(self, p):
clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
hclip = clip.flip((-1))
hclip = clip.flip(-1)
out = transforms.RandomHorizontalFlipVideo(p=p)(clip)
if p == 0:
......
......@@ -43,7 +43,7 @@ def set_image_backend(backend):
"""
global _image_backend
if backend not in ["PIL", "accimage"]:
raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'".format(backend))
raise ValueError(f"Invalid backend '{backend}'. Options are 'PIL' and 'accimage'")
_image_backend = backend
......@@ -74,7 +74,7 @@ def set_video_backend(backend):
if backend not in ["pyav", "video_reader"]:
raise ValueError("Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend)
if backend == "video_reader" and not io._HAS_VIDEO_OPT:
message = "video_reader video backend is not available." " Please compile torchvision from source and try again"
message = "video_reader video backend is not available. Please compile torchvision from source and try again"
warnings.warn(message)
else:
_video_backend = backend
......
......@@ -40,9 +40,7 @@ class Caltech101(VisionDataset):
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(Caltech101, self).__init__(
os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform
)
super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
os.makedirs(self.root, exist_ok=True)
if isinstance(target_type, str):
target_type = [target_type]
......@@ -52,7 +50,7 @@ class Caltech101(VisionDataset):
self.download()
if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
self.categories.remove("BACKGROUND_Google") # this is not a real class
......@@ -90,7 +88,7 @@ class Caltech101(VisionDataset):
self.root,
"101_ObjectCategories",
self.categories[self.y[index]],
"image_{:04d}.jpg".format(self.index[index]),
f"image_{self.index[index]:04d}.jpg",
)
)
......@@ -104,7 +102,7 @@ class Caltech101(VisionDataset):
self.root,
"Annotations",
self.annotation_categories[self.y[index]],
"annotation_{:04d}.mat".format(self.index[index]),
f"annotation_{self.index[index]:04d}.mat",
)
)
target.append(data["obj_contour"])
......@@ -167,16 +165,14 @@ class Caltech256(VisionDataset):
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(Caltech256, self).__init__(
os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform
)
super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
os.makedirs(self.root, exist_ok=True)
if download:
self.download()
if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
self.index: List[int] = []
......@@ -205,7 +201,7 @@ class Caltech256(VisionDataset):
self.root,
"256_ObjectCategories",
self.categories[self.y[index]],
"{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index]),
f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg",
)
)
......
......@@ -66,7 +66,7 @@ class CelebA(VisionDataset):
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(CelebA, self).__init__(root, transform=transform, target_transform=target_transform)
super().__init__(root, transform=transform, target_transform=target_transform)
self.split = split
if isinstance(target_type, list):
self.target_type = target_type
......@@ -80,7 +80,7 @@ class CelebA(VisionDataset):
self.download()
if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
split_map = {
"train": 0,
......@@ -166,7 +166,7 @@ class CelebA(VisionDataset):
target.append(self.landmarks_align[index, :])
else:
# TODO: refactor with utils.verify_str_arg
raise ValueError('Target type "{}" is not recognized.'.format(t))
raise ValueError(f'Target type "{t}" is not recognized.')
if self.transform is not None:
X = self.transform(X)
......
......@@ -58,7 +58,7 @@ class CIFAR10(VisionDataset):
download: bool = False,
) -> None:
super(CIFAR10, self).__init__(root, transform=transform, target_transform=target_transform)
super().__init__(root, transform=transform, target_transform=target_transform)
self.train = train # training set or test set
......@@ -66,7 +66,7 @@ class CIFAR10(VisionDataset):
self.download()
if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
if self.train:
downloaded_list = self.train_list
......@@ -95,9 +95,7 @@ class CIFAR10(VisionDataset):
def _load_meta(self) -> None:
path = os.path.join(self.root, self.base_folder, self.meta["filename"])
if not check_integrity(path, self.meta["md5"]):
raise RuntimeError(
"Dataset metadata file not found or corrupted." + " You can use download=True to download it"
)
raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it")
with open(path, "rb") as infile:
data = pickle.load(infile, encoding="latin1")
self.classes = data[self.meta["key"]]
......@@ -144,7 +142,8 @@ class CIFAR10(VisionDataset):
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test")
split = "Train" if self.train is True else "Test"
return f"Split: {split}"
class CIFAR100(CIFAR10):
......
......@@ -111,7 +111,7 @@ class Cityscapes(VisionDataset):
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
) -> None:
super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
super().__init__(root, transforms, transform, target_transform)
self.mode = "gtFine" if mode == "fine" else "gtCoarse"
self.images_dir = os.path.join(self.root, "leftImg8bit", split)
self.targets_dir = os.path.join(self.root, self.mode, split)
......@@ -125,7 +125,7 @@ class Cityscapes(VisionDataset):
valid_modes = ("train", "test", "val")
else:
valid_modes = ("train", "train_extra", "val")
msg = "Unknown value '{}' for argument split if mode is '{}'. " "Valid values are {{{}}}."
msg = "Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}."
msg = msg.format(split, mode, iterable_to_str(valid_modes))
verify_str_arg(split, "split", valid_modes, msg)
......@@ -139,14 +139,14 @@ class Cityscapes(VisionDataset):
if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
if split == "train_extra":
image_dir_zip = os.path.join(self.root, "leftImg8bit{}".format("_trainextra.zip"))
image_dir_zip = os.path.join(self.root, "leftImg8bit_trainextra.zip")
else:
image_dir_zip = os.path.join(self.root, "leftImg8bit{}".format("_trainvaltest.zip"))
image_dir_zip = os.path.join(self.root, "leftImg8bit_trainvaltest.zip")
if self.mode == "gtFine":
target_dir_zip = os.path.join(self.root, "{}{}".format(self.mode, "_trainvaltest.zip"))
target_dir_zip = os.path.join(self.root, f"{self.mode}_trainvaltest.zip")
elif self.mode == "gtCoarse":
target_dir_zip = os.path.join(self.root, "{}{}".format(self.mode, ".zip"))
target_dir_zip = os.path.join(self.root, f"{self.mode}.zip")
if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
extract_archive(from_path=image_dir_zip, to_path=self.root)
......@@ -206,16 +206,16 @@ class Cityscapes(VisionDataset):
return "\n".join(lines).format(**self.__dict__)
def _load_json(self, path: str) -> Dict[str, Any]:
with open(path, "r") as file:
with open(path) as file:
data = json.load(file)
return data
def _get_target_suffix(self, mode: str, target_type: str) -> str:
if target_type == "instance":
return "{}_instanceIds.png".format(mode)
return f"{mode}_instanceIds.png"
elif target_type == "semantic":
return "{}_labelIds.png".format(mode)
return f"{mode}_labelIds.png"
elif target_type == "color":
return "{}_color.png".format(mode)
return f"{mode}_color.png"
else:
return "{}_polygons.json".format(mode)
return f"{mode}_polygons.json"
......@@ -31,9 +31,7 @@ class FakeData(VisionDataset):
target_transform: Optional[Callable] = None,
random_offset: int = 0,
) -> None:
super(FakeData, self).__init__(
None, transform=transform, target_transform=target_transform # type: ignore[arg-type]
)
super().__init__(None, transform=transform, target_transform=target_transform) # type: ignore[arg-type]
self.size = size
self.num_classes = num_classes
self.image_size = image_size
......@@ -49,7 +47,7 @@ class FakeData(VisionDataset):
"""
# create random image that is consistent with the index id
if index >= len(self):
raise IndexError("{} index out of range".format(self.__class__.__name__))
raise IndexError(f"{self.__class__.__name__} index out of range")
rng_state = torch.get_rng_state()
torch.manual_seed(index + self.random_offset)
img = torch.randn(*self.image_size)
......
......@@ -13,7 +13,7 @@ class Flickr8kParser(HTMLParser):
"""Parser for extracting captions from the Flickr8k dataset web page."""
def __init__(self, root: str) -> None:
super(Flickr8kParser, self).__init__()
super().__init__()
self.root = root
......@@ -71,7 +71,7 @@ class Flickr8k(VisionDataset):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super(Flickr8k, self).__init__(root, transform=transform, target_transform=target_transform)
super().__init__(root, transform=transform, target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file)
# Read annotations and store in a dict
......@@ -127,7 +127,7 @@ class Flickr30k(VisionDataset):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super(Flickr30k, self).__init__(root, transform=transform, target_transform=target_transform)
super().__init__(root, transform=transform, target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file)
# Read annotations and store in a dict
......
......@@ -140,7 +140,7 @@ class DatasetFolder(VisionDataset):
target_transform: Optional[Callable] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> None:
super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform)
super().__init__(root, transform=transform, target_transform=target_transform)
classes, class_to_idx = self.find_classes(self.root)
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
......@@ -254,7 +254,7 @@ def accimage_loader(path: str) -> Any:
try:
return accimage.Image(path)
except IOError:
except OSError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
......@@ -306,7 +306,7 @@ class ImageFolder(DatasetFolder):
loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
):
super(ImageFolder, self).__init__(
super().__init__(
root,
loader,
IMG_EXTENSIONS if is_valid_file is None else None,
......
......@@ -72,9 +72,9 @@ class HMDB51(VisionDataset):
_video_min_dimension: int = 0,
_audio_samples: int = 0,
) -> None:
super(HMDB51, self).__init__(root)
super().__init__(root)
if fold not in (1, 2, 3):
raise ValueError("fold should be between 1 and 3, got {}".format(fold))
raise ValueError(f"fold should be between 1 and 3, got {fold}")
extensions = ("avi",)
self.classes, class_to_idx = find_classes(self.root)
......@@ -113,7 +113,7 @@ class HMDB51(VisionDataset):
def _select_fold(self, video_list: List[str], annotations_dir: str, fold: int, train: bool) -> List[int]:
target_tag = self.TRAIN_TAG if train else self.TEST_TAG
split_pattern_name = "*test_split{}.txt".format(fold)
split_pattern_name = f"*test_split{fold}.txt"
split_pattern_path = os.path.join(annotations_dir, split_pattern_name)
annotation_paths = glob.glob(split_pattern_path)
selected_files = set()
......
......@@ -49,7 +49,7 @@ class ImageNet(ImageFolder):
)
raise RuntimeError(msg)
elif download is False:
msg = "The use of the download flag is deprecated, since the dataset " "is no longer publicly accessible."
msg = "The use of the download flag is deprecated, since the dataset is no longer publicly accessible."
warnings.warn(msg, RuntimeWarning)
root = self.root = os.path.expanduser(root)
......@@ -58,7 +58,7 @@ class ImageNet(ImageFolder):
self.parse_archives()
wnid_to_classes = load_meta_file(self.root)[0]
super(ImageNet, self).__init__(self.split_folder, **kwargs)
super().__init__(self.split_folder, **kwargs)
self.root = root
self.wnids = self.classes
......@@ -132,7 +132,7 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
with open(file, "r") as txtfh:
with open(file) as txtfh:
val_idcs = txtfh.readlines()
return [int(val_idx) for val_idx in val_idcs]
......@@ -215,7 +215,7 @@ def parse_val_archive(
val_root = os.path.join(root, folder)
extract_archive(os.path.join(root, file), val_root)
images = sorted([os.path.join(val_root, image) for image in os.listdir(val_root)])
images = sorted(os.path.join(val_root, image) for image in os.listdir(val_root))
for wnid in set(wnids):
os.mkdir(os.path.join(val_root, wnid))
......
......@@ -74,16 +74,14 @@ class INaturalist(VisionDataset):
) -> None:
self.version = verify_str_arg(version, "version", DATASET_URLS.keys())
super(INaturalist, self).__init__(
os.path.join(root, version), transform=transform, target_transform=target_transform
)
super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)
os.makedirs(root, exist_ok=True)
if download:
self.download()
if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
self.all_categories: List[str] = []
......
......@@ -175,7 +175,7 @@ class Kinetics(VisionDataset):
split_url_filepath = path.join(file_list_path, path.basename(split_url))
if not check_integrity(split_url_filepath):
download_url(split_url, file_list_path)
list_video_urls = open(split_url_filepath, "r")
list_video_urls = open(split_url_filepath)
if self.num_download_workers == 1:
for line in list_video_urls.readlines():
......@@ -309,7 +309,7 @@ class Kinetics400(Kinetics):
"Kinetics400. Please use Kinetics instead."
)
super(Kinetics400, self).__init__(
super().__init__(
root=root,
frames_per_clip=frames_per_clip,
_legacy=True,
......
......@@ -39,9 +39,7 @@ class _LFW(VisionDataset):
target_transform: Optional[Callable] = None,
download: bool = False,
):
super(_LFW, self).__init__(
os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform
)
super().__init__(os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform)
self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys())
images_dir, self.filename, self.md5 = self.file_dict[self.image_set]
......@@ -55,7 +53,7 @@ class _LFW(VisionDataset):
self.download()
if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
self.images_dir = os.path.join(self.root, images_dir)
......@@ -122,14 +120,14 @@ class LFWPeople(_LFW):
target_transform: Optional[Callable] = None,
download: bool = False,
):
super(LFWPeople, self).__init__(root, split, image_set, "people", transform, target_transform, download)
super().__init__(root, split, image_set, "people", transform, target_transform, download)
self.class_to_idx = self._get_classes()
self.data, self.targets = self._get_people()
def _get_people(self):
data, targets = [], []
with open(os.path.join(self.root, self.labels_file), "r") as f:
with open(os.path.join(self.root, self.labels_file)) as f:
lines = f.readlines()
n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0)
......@@ -146,7 +144,7 @@ class LFWPeople(_LFW):
return data, targets
def _get_classes(self):
with open(os.path.join(self.root, self.names), "r") as f:
with open(os.path.join(self.root, self.names)) as f:
lines = f.readlines()
names = [line.strip().split()[0] for line in lines]
class_to_idx = {name: i for i, name in enumerate(names)}
......@@ -172,7 +170,7 @@ class LFWPeople(_LFW):
return img, target
def extra_repr(self) -> str:
return super().extra_repr() + "\nClasses (identities): {}".format(len(self.class_to_idx))
return super().extra_repr() + f"\nClasses (identities): {len(self.class_to_idx)}"
class LFWPairs(_LFW):
......@@ -204,13 +202,13 @@ class LFWPairs(_LFW):
target_transform: Optional[Callable] = None,
download: bool = False,
):
super(LFWPairs, self).__init__(root, split, image_set, "pairs", transform, target_transform, download)
super().__init__(root, split, image_set, "pairs", transform, target_transform, download)
self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)
def _get_pairs(self, images_dir):
pair_names, data, targets = [], [], []
with open(os.path.join(self.root, self.labels_file), "r") as f:
with open(os.path.join(self.root, self.labels_file)) as f:
lines = f.readlines()
if self.split == "10fold":
n_folds, n_pairs = lines[0].split("\t")
......
......@@ -18,7 +18,7 @@ class LSUNClass(VisionDataset):
) -> None:
import lmdb
super(LSUNClass, self).__init__(root, transform=transform, target_transform=target_transform)
super().__init__(root, transform=transform, target_transform=target_transform)
self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
with self.env.begin(write=False) as txn:
......@@ -77,7 +77,7 @@ class LSUN(VisionDataset):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super(LSUN, self).__init__(root, transform=transform, target_transform=target_transform)
super().__init__(root, transform=transform, target_transform=target_transform)
self.classes = self._verify_classes(classes)
# for each class, create an LSUNClassDataset
......@@ -117,11 +117,11 @@ class LSUN(VisionDataset):
classes = [c + "_" + classes for c in categories]
except ValueError:
if not isinstance(classes, Iterable):
msg = "Expected type str or Iterable for argument classes, " "but got type {}."
msg = "Expected type str or Iterable for argument classes, but got type {}."
raise ValueError(msg.format(type(classes)))
classes = list(classes)
msg_fmtstr_type = "Expected type str for elements in argument classes, " "but got type {}."
msg_fmtstr_type = "Expected type str for elements in argument classes, but got type {}."
for c in classes:
verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
c_short = c.split("_")
......
......@@ -88,7 +88,7 @@ class MNIST(VisionDataset):
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(MNIST, self).__init__(root, transform=transform, target_transform=target_transform)
super().__init__(root, transform=transform, target_transform=target_transform)
self.train = train # training set or test set
if self._check_legacy_exist():
......@@ -99,7 +99,7 @@ class MNIST(VisionDataset):
self.download()
if not self._check_exists():
raise RuntimeError("Dataset not found." + " You can use download=True to download it")
raise RuntimeError("Dataset not found. You can use download=True to download it")
self.data, self.targets = self._load_data()
......@@ -181,21 +181,22 @@ class MNIST(VisionDataset):
# download files
for filename, md5 in self.resources:
for mirror in self.mirrors:
url = "{}{}".format(mirror, filename)
url = f"{mirror}{filename}"
try:
print("Downloading {}".format(url))
print(f"Downloading {url}")
download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
except URLError as error:
print("Failed to download (trying next):\n{}".format(error))
print(f"Failed to download (trying next):\n{error}")
continue
finally:
print()
break
else:
raise RuntimeError("Error downloading {}".format(filename))
raise RuntimeError(f"Error downloading {filename}")
def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test")
split = "Train" if self.train is True else "Test"
return f"Split: {split}"
class FashionMNIST(MNIST):
......@@ -293,16 +294,16 @@ class EMNIST(MNIST):
self.split = verify_str_arg(split, "split", self.splits)
self.training_file = self._training_file(split)
self.test_file = self._test_file(split)
super(EMNIST, self).__init__(root, **kwargs)
super().__init__(root, **kwargs)
self.classes = self.classes_split_dict[self.split]
@staticmethod
def _training_file(split) -> str:
return "training_{}.pt".format(split)
return f"training_{split}.pt"
@staticmethod
def _test_file(split) -> str:
return "test_{}.pt".format(split)
return f"test_{split}.pt"
@property
def _file_prefix(self) -> str:
......@@ -424,7 +425,7 @@ class QMNIST(MNIST):
self.data_file = what + ".pt"
self.training_file = self.data_file
self.test_file = self.data_file
super(QMNIST, self).__init__(root, train, **kwargs)
super().__init__(root, train, **kwargs)
@property
def images_file(self) -> str:
......@@ -482,7 +483,7 @@ class QMNIST(MNIST):
return img, target
def extra_repr(self) -> str:
return "Split: {}".format(self.what)
return f"Split: {self.what}"
def get_int(b: bytes) -> int:
......
......@@ -39,19 +39,19 @@ class Omniglot(VisionDataset):
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(Omniglot, self).__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
self.background = background
if download:
self.download()
if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
self.target_folder = join(self.root, self._get_target_folder())
self._alphabets = list_dir(self.target_folder)
self._characters: List[str] = sum(
[[join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets], []
([join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets), []
)
self._character_images = [
[(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
......
......@@ -89,11 +89,11 @@ class PhotoTour(VisionDataset):
def __init__(
self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False
) -> None:
super(PhotoTour, self).__init__(root, transform=transform)
super().__init__(root, transform=transform)
self.name = name
self.data_dir = os.path.join(self.root, name)
self.data_down = os.path.join(self.root, "{}.zip".format(name))
self.data_file = os.path.join(self.root, "{}.pt".format(name))
self.data_down = os.path.join(self.root, f"{name}.zip")
self.data_file = os.path.join(self.root, f"{name}.pt")
self.train = train
self.mean = self.means[name]
......@@ -139,7 +139,7 @@ class PhotoTour(VisionDataset):
def download(self) -> None:
if self._check_datafile_exists():
print("# Found cached data {}".format(self.data_file))
print(f"# Found cached data {self.data_file}")
return
if not self._check_downloaded():
......@@ -151,7 +151,7 @@ class PhotoTour(VisionDataset):
download_url(url, self.root, filename, md5)
print("# Extracting data {}\n".format(self.data_down))
print(f"# Extracting data {self.data_down}\n")
import zipfile
......@@ -162,7 +162,7 @@ class PhotoTour(VisionDataset):
def cache(self) -> None:
# process and save as torch files
print("# Caching data {}".format(self.data_file))
print(f"# Caching data {self.data_file}")
dataset = (
read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
......@@ -174,7 +174,8 @@ class PhotoTour(VisionDataset):
torch.save(dataset, f)
def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test")
split = "Train" if self.train is True else "Test"
return f"Split: {split}"
def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
......@@ -209,7 +210,7 @@ def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
"""Return a Tensor containing the list of labels
Read the file and keep only the ID of the 3D point.
"""
with open(os.path.join(data_dir, info_file), "r") as f:
with open(os.path.join(data_dir, info_file)) as f:
labels = [int(line.split()[0]) for line in f]
return torch.LongTensor(labels)
......@@ -220,7 +221,7 @@ def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
Matches are represented with a 1, non matches with a 0.
"""
matches = []
with open(os.path.join(data_dir, matches_file), "r") as f:
with open(os.path.join(data_dir, matches_file)) as f:
for line in f:
line_split = line.split()
matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
......
......@@ -117,7 +117,7 @@ class Places365(VisionDataset):
if not self._check_integrity(file, md5, download):
self.download_devkit()
with open(file, "r") as fh:
with open(file) as fh:
class_to_idx = dict(process(line) for line in fh)
return sorted(class_to_idx.keys()), class_to_idx
......@@ -132,7 +132,7 @@ class Places365(VisionDataset):
if not self._check_integrity(file, md5, download):
self.download_devkit()
with open(file, "r") as fh:
with open(file) as fh:
images = [process(line) for line in fh]
_, targets = zip(*images)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment