Unverified Commit d367a01a authored by Jirka Borovec's avatar Jirka Borovec Committed by GitHub
Browse files

Use f-strings almost everywhere, and other cleanups by applying pyupgrade (#4585)


Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 50dfe207
...@@ -371,7 +371,7 @@ def test_x_crop_save(method, tmpdir): ...@@ -371,7 +371,7 @@ def test_x_crop_save(method, tmpdir):
] ]
) )
scripted_fn = torch.jit.script(fn) scripted_fn = torch.jit.script(fn)
scripted_fn.save(os.path.join(tmpdir, "t_op_list_{}.pt".format(method))) scripted_fn.save(os.path.join(tmpdir, f"t_op_list_{method}.pt"))
class TestResize: class TestResize:
...@@ -816,7 +816,7 @@ def test_compose(device): ...@@ -816,7 +816,7 @@ def test_compose(device):
transformed_tensor = transforms(tensor) transformed_tensor = transforms(tensor)
torch.manual_seed(12) torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor) transformed_tensor_script = scripted_fn(tensor)
assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms)) assert_equal(transformed_tensor, transformed_tensor_script, msg=f"{transforms}")
t = T.Compose( t = T.Compose(
[ [
...@@ -854,7 +854,7 @@ def test_random_apply(device): ...@@ -854,7 +854,7 @@ def test_random_apply(device):
transformed_tensor = transforms(tensor) transformed_tensor = transforms(tensor)
torch.manual_seed(12) torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor) transformed_tensor_script = scripted_fn(tensor)
assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms)) assert_equal(transformed_tensor, transformed_tensor_script, msg=f"{transforms}")
if device == "cpu": if device == "cpu":
# Can't check this twice, otherwise # Can't check this twice, otherwise
......
...@@ -163,7 +163,7 @@ class TestVideoTransforms: ...@@ -163,7 +163,7 @@ class TestVideoTransforms:
@pytest.mark.parametrize("p", (0, 1)) @pytest.mark.parametrize("p", (0, 1))
def test_random_horizontal_flip_video(self, p): def test_random_horizontal_flip_video(self, p):
clip = torch.rand((3, 4, 112, 112), dtype=torch.float) clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
hclip = clip.flip((-1)) hclip = clip.flip(-1)
out = transforms.RandomHorizontalFlipVideo(p=p)(clip) out = transforms.RandomHorizontalFlipVideo(p=p)(clip)
if p == 0: if p == 0:
......
...@@ -43,7 +43,7 @@ def set_image_backend(backend): ...@@ -43,7 +43,7 @@ def set_image_backend(backend):
""" """
global _image_backend global _image_backend
if backend not in ["PIL", "accimage"]: if backend not in ["PIL", "accimage"]:
raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'".format(backend)) raise ValueError(f"Invalid backend '{backend}'. Options are 'PIL' and 'accimage'")
_image_backend = backend _image_backend = backend
...@@ -74,7 +74,7 @@ def set_video_backend(backend): ...@@ -74,7 +74,7 @@ def set_video_backend(backend):
if backend not in ["pyav", "video_reader"]: if backend not in ["pyav", "video_reader"]:
raise ValueError("Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend) raise ValueError("Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend)
if backend == "video_reader" and not io._HAS_VIDEO_OPT: if backend == "video_reader" and not io._HAS_VIDEO_OPT:
message = "video_reader video backend is not available." " Please compile torchvision from source and try again" message = "video_reader video backend is not available. Please compile torchvision from source and try again"
warnings.warn(message) warnings.warn(message)
else: else:
_video_backend = backend _video_backend = backend
......
...@@ -40,9 +40,7 @@ class Caltech101(VisionDataset): ...@@ -40,9 +40,7 @@ class Caltech101(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(Caltech101, self).__init__( super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform
)
os.makedirs(self.root, exist_ok=True) os.makedirs(self.root, exist_ok=True)
if isinstance(target_type, str): if isinstance(target_type, str):
target_type = [target_type] target_type = [target_type]
...@@ -52,7 +50,7 @@ class Caltech101(VisionDataset): ...@@ -52,7 +50,7 @@ class Caltech101(VisionDataset):
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories"))) self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
self.categories.remove("BACKGROUND_Google") # this is not a real class self.categories.remove("BACKGROUND_Google") # this is not a real class
...@@ -90,7 +88,7 @@ class Caltech101(VisionDataset): ...@@ -90,7 +88,7 @@ class Caltech101(VisionDataset):
self.root, self.root,
"101_ObjectCategories", "101_ObjectCategories",
self.categories[self.y[index]], self.categories[self.y[index]],
"image_{:04d}.jpg".format(self.index[index]), f"image_{self.index[index]:04d}.jpg",
) )
) )
...@@ -104,7 +102,7 @@ class Caltech101(VisionDataset): ...@@ -104,7 +102,7 @@ class Caltech101(VisionDataset):
self.root, self.root,
"Annotations", "Annotations",
self.annotation_categories[self.y[index]], self.annotation_categories[self.y[index]],
"annotation_{:04d}.mat".format(self.index[index]), f"annotation_{self.index[index]:04d}.mat",
) )
) )
target.append(data["obj_contour"]) target.append(data["obj_contour"])
...@@ -167,16 +165,14 @@ class Caltech256(VisionDataset): ...@@ -167,16 +165,14 @@ class Caltech256(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(Caltech256, self).__init__( super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform
)
os.makedirs(self.root, exist_ok=True) os.makedirs(self.root, exist_ok=True)
if download: if download:
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories"))) self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
self.index: List[int] = [] self.index: List[int] = []
...@@ -205,7 +201,7 @@ class Caltech256(VisionDataset): ...@@ -205,7 +201,7 @@ class Caltech256(VisionDataset):
self.root, self.root,
"256_ObjectCategories", "256_ObjectCategories",
self.categories[self.y[index]], self.categories[self.y[index]],
"{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index]), f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg",
) )
) )
......
...@@ -66,7 +66,7 @@ class CelebA(VisionDataset): ...@@ -66,7 +66,7 @@ class CelebA(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(CelebA, self).__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self.split = split self.split = split
if isinstance(target_type, list): if isinstance(target_type, list):
self.target_type = target_type self.target_type = target_type
...@@ -80,7 +80,7 @@ class CelebA(VisionDataset): ...@@ -80,7 +80,7 @@ class CelebA(VisionDataset):
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
split_map = { split_map = {
"train": 0, "train": 0,
...@@ -166,7 +166,7 @@ class CelebA(VisionDataset): ...@@ -166,7 +166,7 @@ class CelebA(VisionDataset):
target.append(self.landmarks_align[index, :]) target.append(self.landmarks_align[index, :])
else: else:
# TODO: refactor with utils.verify_str_arg # TODO: refactor with utils.verify_str_arg
raise ValueError('Target type "{}" is not recognized.'.format(t)) raise ValueError(f'Target type "{t}" is not recognized.')
if self.transform is not None: if self.transform is not None:
X = self.transform(X) X = self.transform(X)
......
...@@ -58,7 +58,7 @@ class CIFAR10(VisionDataset): ...@@ -58,7 +58,7 @@ class CIFAR10(VisionDataset):
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(CIFAR10, self).__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self.train = train # training set or test set self.train = train # training set or test set
...@@ -66,7 +66,7 @@ class CIFAR10(VisionDataset): ...@@ -66,7 +66,7 @@ class CIFAR10(VisionDataset):
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
if self.train: if self.train:
downloaded_list = self.train_list downloaded_list = self.train_list
...@@ -95,9 +95,7 @@ class CIFAR10(VisionDataset): ...@@ -95,9 +95,7 @@ class CIFAR10(VisionDataset):
def _load_meta(self) -> None: def _load_meta(self) -> None:
path = os.path.join(self.root, self.base_folder, self.meta["filename"]) path = os.path.join(self.root, self.base_folder, self.meta["filename"])
if not check_integrity(path, self.meta["md5"]): if not check_integrity(path, self.meta["md5"]):
raise RuntimeError( raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it")
"Dataset metadata file not found or corrupted." + " You can use download=True to download it"
)
with open(path, "rb") as infile: with open(path, "rb") as infile:
data = pickle.load(infile, encoding="latin1") data = pickle.load(infile, encoding="latin1")
self.classes = data[self.meta["key"]] self.classes = data[self.meta["key"]]
...@@ -144,7 +142,8 @@ class CIFAR10(VisionDataset): ...@@ -144,7 +142,8 @@ class CIFAR10(VisionDataset):
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
def extra_repr(self) -> str: def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test") split = "Train" if self.train is True else "Test"
return f"Split: {split}"
class CIFAR100(CIFAR10): class CIFAR100(CIFAR10):
......
...@@ -111,7 +111,7 @@ class Cityscapes(VisionDataset): ...@@ -111,7 +111,7 @@ class Cityscapes(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
) -> None: ) -> None:
super(Cityscapes, self).__init__(root, transforms, transform, target_transform) super().__init__(root, transforms, transform, target_transform)
self.mode = "gtFine" if mode == "fine" else "gtCoarse" self.mode = "gtFine" if mode == "fine" else "gtCoarse"
self.images_dir = os.path.join(self.root, "leftImg8bit", split) self.images_dir = os.path.join(self.root, "leftImg8bit", split)
self.targets_dir = os.path.join(self.root, self.mode, split) self.targets_dir = os.path.join(self.root, self.mode, split)
...@@ -125,7 +125,7 @@ class Cityscapes(VisionDataset): ...@@ -125,7 +125,7 @@ class Cityscapes(VisionDataset):
valid_modes = ("train", "test", "val") valid_modes = ("train", "test", "val")
else: else:
valid_modes = ("train", "train_extra", "val") valid_modes = ("train", "train_extra", "val")
msg = "Unknown value '{}' for argument split if mode is '{}'. " "Valid values are {{{}}}." msg = "Unknown value '{}' for argument split if mode is '{}'. Valid values are {{{}}}."
msg = msg.format(split, mode, iterable_to_str(valid_modes)) msg = msg.format(split, mode, iterable_to_str(valid_modes))
verify_str_arg(split, "split", valid_modes, msg) verify_str_arg(split, "split", valid_modes, msg)
...@@ -139,14 +139,14 @@ class Cityscapes(VisionDataset): ...@@ -139,14 +139,14 @@ class Cityscapes(VisionDataset):
if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
if split == "train_extra": if split == "train_extra":
image_dir_zip = os.path.join(self.root, "leftImg8bit{}".format("_trainextra.zip")) image_dir_zip = os.path.join(self.root, "leftImg8bit_trainextra.zip")
else: else:
image_dir_zip = os.path.join(self.root, "leftImg8bit{}".format("_trainvaltest.zip")) image_dir_zip = os.path.join(self.root, "leftImg8bit_trainvaltest.zip")
if self.mode == "gtFine": if self.mode == "gtFine":
target_dir_zip = os.path.join(self.root, "{}{}".format(self.mode, "_trainvaltest.zip")) target_dir_zip = os.path.join(self.root, f"{self.mode}_trainvaltest.zip")
elif self.mode == "gtCoarse": elif self.mode == "gtCoarse":
target_dir_zip = os.path.join(self.root, "{}{}".format(self.mode, ".zip")) target_dir_zip = os.path.join(self.root, f"{self.mode}.zip")
if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip): if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip):
extract_archive(from_path=image_dir_zip, to_path=self.root) extract_archive(from_path=image_dir_zip, to_path=self.root)
...@@ -206,16 +206,16 @@ class Cityscapes(VisionDataset): ...@@ -206,16 +206,16 @@ class Cityscapes(VisionDataset):
return "\n".join(lines).format(**self.__dict__) return "\n".join(lines).format(**self.__dict__)
def _load_json(self, path: str) -> Dict[str, Any]: def _load_json(self, path: str) -> Dict[str, Any]:
with open(path, "r") as file: with open(path) as file:
data = json.load(file) data = json.load(file)
return data return data
def _get_target_suffix(self, mode: str, target_type: str) -> str: def _get_target_suffix(self, mode: str, target_type: str) -> str:
if target_type == "instance": if target_type == "instance":
return "{}_instanceIds.png".format(mode) return f"{mode}_instanceIds.png"
elif target_type == "semantic": elif target_type == "semantic":
return "{}_labelIds.png".format(mode) return f"{mode}_labelIds.png"
elif target_type == "color": elif target_type == "color":
return "{}_color.png".format(mode) return f"{mode}_color.png"
else: else:
return "{}_polygons.json".format(mode) return f"{mode}_polygons.json"
...@@ -31,9 +31,7 @@ class FakeData(VisionDataset): ...@@ -31,9 +31,7 @@ class FakeData(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
random_offset: int = 0, random_offset: int = 0,
) -> None: ) -> None:
super(FakeData, self).__init__( super().__init__(None, transform=transform, target_transform=target_transform) # type: ignore[arg-type]
None, transform=transform, target_transform=target_transform # type: ignore[arg-type]
)
self.size = size self.size = size
self.num_classes = num_classes self.num_classes = num_classes
self.image_size = image_size self.image_size = image_size
...@@ -49,7 +47,7 @@ class FakeData(VisionDataset): ...@@ -49,7 +47,7 @@ class FakeData(VisionDataset):
""" """
# create random image that is consistent with the index id # create random image that is consistent with the index id
if index >= len(self): if index >= len(self):
raise IndexError("{} index out of range".format(self.__class__.__name__)) raise IndexError(f"{self.__class__.__name__} index out of range")
rng_state = torch.get_rng_state() rng_state = torch.get_rng_state()
torch.manual_seed(index + self.random_offset) torch.manual_seed(index + self.random_offset)
img = torch.randn(*self.image_size) img = torch.randn(*self.image_size)
......
...@@ -13,7 +13,7 @@ class Flickr8kParser(HTMLParser): ...@@ -13,7 +13,7 @@ class Flickr8kParser(HTMLParser):
"""Parser for extracting captions from the Flickr8k dataset web page.""" """Parser for extracting captions from the Flickr8k dataset web page."""
def __init__(self, root: str) -> None: def __init__(self, root: str) -> None:
super(Flickr8kParser, self).__init__() super().__init__()
self.root = root self.root = root
...@@ -71,7 +71,7 @@ class Flickr8k(VisionDataset): ...@@ -71,7 +71,7 @@ class Flickr8k(VisionDataset):
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
) -> None: ) -> None:
super(Flickr8k, self).__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file) self.ann_file = os.path.expanduser(ann_file)
# Read annotations and store in a dict # Read annotations and store in a dict
...@@ -127,7 +127,7 @@ class Flickr30k(VisionDataset): ...@@ -127,7 +127,7 @@ class Flickr30k(VisionDataset):
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
) -> None: ) -> None:
super(Flickr30k, self).__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file) self.ann_file = os.path.expanduser(ann_file)
# Read annotations and store in a dict # Read annotations and store in a dict
......
...@@ -140,7 +140,7 @@ class DatasetFolder(VisionDataset): ...@@ -140,7 +140,7 @@ class DatasetFolder(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
is_valid_file: Optional[Callable[[str], bool]] = None, is_valid_file: Optional[Callable[[str], bool]] = None,
) -> None: ) -> None:
super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
classes, class_to_idx = self.find_classes(self.root) classes, class_to_idx = self.find_classes(self.root)
samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file) samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
...@@ -254,7 +254,7 @@ def accimage_loader(path: str) -> Any: ...@@ -254,7 +254,7 @@ def accimage_loader(path: str) -> Any:
try: try:
return accimage.Image(path) return accimage.Image(path)
except IOError: except OSError:
# Potentially a decoding problem, fall back to PIL.Image # Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path) return pil_loader(path)
...@@ -306,7 +306,7 @@ class ImageFolder(DatasetFolder): ...@@ -306,7 +306,7 @@ class ImageFolder(DatasetFolder):
loader: Callable[[str], Any] = default_loader, loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None, is_valid_file: Optional[Callable[[str], bool]] = None,
): ):
super(ImageFolder, self).__init__( super().__init__(
root, root,
loader, loader,
IMG_EXTENSIONS if is_valid_file is None else None, IMG_EXTENSIONS if is_valid_file is None else None,
......
...@@ -72,9 +72,9 @@ class HMDB51(VisionDataset): ...@@ -72,9 +72,9 @@ class HMDB51(VisionDataset):
_video_min_dimension: int = 0, _video_min_dimension: int = 0,
_audio_samples: int = 0, _audio_samples: int = 0,
) -> None: ) -> None:
super(HMDB51, self).__init__(root) super().__init__(root)
if fold not in (1, 2, 3): if fold not in (1, 2, 3):
raise ValueError("fold should be between 1 and 3, got {}".format(fold)) raise ValueError(f"fold should be between 1 and 3, got {fold}")
extensions = ("avi",) extensions = ("avi",)
self.classes, class_to_idx = find_classes(self.root) self.classes, class_to_idx = find_classes(self.root)
...@@ -113,7 +113,7 @@ class HMDB51(VisionDataset): ...@@ -113,7 +113,7 @@ class HMDB51(VisionDataset):
def _select_fold(self, video_list: List[str], annotations_dir: str, fold: int, train: bool) -> List[int]: def _select_fold(self, video_list: List[str], annotations_dir: str, fold: int, train: bool) -> List[int]:
target_tag = self.TRAIN_TAG if train else self.TEST_TAG target_tag = self.TRAIN_TAG if train else self.TEST_TAG
split_pattern_name = "*test_split{}.txt".format(fold) split_pattern_name = f"*test_split{fold}.txt"
split_pattern_path = os.path.join(annotations_dir, split_pattern_name) split_pattern_path = os.path.join(annotations_dir, split_pattern_name)
annotation_paths = glob.glob(split_pattern_path) annotation_paths = glob.glob(split_pattern_path)
selected_files = set() selected_files = set()
......
...@@ -49,7 +49,7 @@ class ImageNet(ImageFolder): ...@@ -49,7 +49,7 @@ class ImageNet(ImageFolder):
) )
raise RuntimeError(msg) raise RuntimeError(msg)
elif download is False: elif download is False:
msg = "The use of the download flag is deprecated, since the dataset " "is no longer publicly accessible." msg = "The use of the download flag is deprecated, since the dataset is no longer publicly accessible."
warnings.warn(msg, RuntimeWarning) warnings.warn(msg, RuntimeWarning)
root = self.root = os.path.expanduser(root) root = self.root = os.path.expanduser(root)
...@@ -58,7 +58,7 @@ class ImageNet(ImageFolder): ...@@ -58,7 +58,7 @@ class ImageNet(ImageFolder):
self.parse_archives() self.parse_archives()
wnid_to_classes = load_meta_file(self.root)[0] wnid_to_classes = load_meta_file(self.root)[0]
super(ImageNet, self).__init__(self.split_folder, **kwargs) super().__init__(self.split_folder, **kwargs)
self.root = root self.root = root
self.wnids = self.classes self.wnids = self.classes
...@@ -132,7 +132,7 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None: ...@@ -132,7 +132,7 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
def parse_val_groundtruth_txt(devkit_root: str) -> List[int]: def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt") file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
with open(file, "r") as txtfh: with open(file) as txtfh:
val_idcs = txtfh.readlines() val_idcs = txtfh.readlines()
return [int(val_idx) for val_idx in val_idcs] return [int(val_idx) for val_idx in val_idcs]
...@@ -215,7 +215,7 @@ def parse_val_archive( ...@@ -215,7 +215,7 @@ def parse_val_archive(
val_root = os.path.join(root, folder) val_root = os.path.join(root, folder)
extract_archive(os.path.join(root, file), val_root) extract_archive(os.path.join(root, file), val_root)
images = sorted([os.path.join(val_root, image) for image in os.listdir(val_root)]) images = sorted(os.path.join(val_root, image) for image in os.listdir(val_root))
for wnid in set(wnids): for wnid in set(wnids):
os.mkdir(os.path.join(val_root, wnid)) os.mkdir(os.path.join(val_root, wnid))
......
...@@ -74,16 +74,14 @@ class INaturalist(VisionDataset): ...@@ -74,16 +74,14 @@ class INaturalist(VisionDataset):
) -> None: ) -> None:
self.version = verify_str_arg(version, "version", DATASET_URLS.keys()) self.version = verify_str_arg(version, "version", DATASET_URLS.keys())
super(INaturalist, self).__init__( super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)
os.path.join(root, version), transform=transform, target_transform=target_transform
)
os.makedirs(root, exist_ok=True) os.makedirs(root, exist_ok=True)
if download: if download:
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
self.all_categories: List[str] = [] self.all_categories: List[str] = []
......
...@@ -175,7 +175,7 @@ class Kinetics(VisionDataset): ...@@ -175,7 +175,7 @@ class Kinetics(VisionDataset):
split_url_filepath = path.join(file_list_path, path.basename(split_url)) split_url_filepath = path.join(file_list_path, path.basename(split_url))
if not check_integrity(split_url_filepath): if not check_integrity(split_url_filepath):
download_url(split_url, file_list_path) download_url(split_url, file_list_path)
list_video_urls = open(split_url_filepath, "r") list_video_urls = open(split_url_filepath)
if self.num_download_workers == 1: if self.num_download_workers == 1:
for line in list_video_urls.readlines(): for line in list_video_urls.readlines():
...@@ -309,7 +309,7 @@ class Kinetics400(Kinetics): ...@@ -309,7 +309,7 @@ class Kinetics400(Kinetics):
"Kinetics400. Please use Kinetics instead." "Kinetics400. Please use Kinetics instead."
) )
super(Kinetics400, self).__init__( super().__init__(
root=root, root=root,
frames_per_clip=frames_per_clip, frames_per_clip=frames_per_clip,
_legacy=True, _legacy=True,
......
...@@ -39,9 +39,7 @@ class _LFW(VisionDataset): ...@@ -39,9 +39,7 @@ class _LFW(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
): ):
super(_LFW, self).__init__( super().__init__(os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform)
os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform
)
self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys()) self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys())
images_dir, self.filename, self.md5 = self.file_dict[self.image_set] images_dir, self.filename, self.md5 = self.file_dict[self.image_set]
...@@ -55,7 +53,7 @@ class _LFW(VisionDataset): ...@@ -55,7 +53,7 @@ class _LFW(VisionDataset):
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
self.images_dir = os.path.join(self.root, images_dir) self.images_dir = os.path.join(self.root, images_dir)
...@@ -122,14 +120,14 @@ class LFWPeople(_LFW): ...@@ -122,14 +120,14 @@ class LFWPeople(_LFW):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
): ):
super(LFWPeople, self).__init__(root, split, image_set, "people", transform, target_transform, download) super().__init__(root, split, image_set, "people", transform, target_transform, download)
self.class_to_idx = self._get_classes() self.class_to_idx = self._get_classes()
self.data, self.targets = self._get_people() self.data, self.targets = self._get_people()
def _get_people(self): def _get_people(self):
data, targets = [], [] data, targets = [], []
with open(os.path.join(self.root, self.labels_file), "r") as f: with open(os.path.join(self.root, self.labels_file)) as f:
lines = f.readlines() lines = f.readlines()
n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0) n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0)
...@@ -146,7 +144,7 @@ class LFWPeople(_LFW): ...@@ -146,7 +144,7 @@ class LFWPeople(_LFW):
return data, targets return data, targets
def _get_classes(self): def _get_classes(self):
with open(os.path.join(self.root, self.names), "r") as f: with open(os.path.join(self.root, self.names)) as f:
lines = f.readlines() lines = f.readlines()
names = [line.strip().split()[0] for line in lines] names = [line.strip().split()[0] for line in lines]
class_to_idx = {name: i for i, name in enumerate(names)} class_to_idx = {name: i for i, name in enumerate(names)}
...@@ -172,7 +170,7 @@ class LFWPeople(_LFW): ...@@ -172,7 +170,7 @@ class LFWPeople(_LFW):
return img, target return img, target
def extra_repr(self) -> str: def extra_repr(self) -> str:
return super().extra_repr() + "\nClasses (identities): {}".format(len(self.class_to_idx)) return super().extra_repr() + f"\nClasses (identities): {len(self.class_to_idx)}"
class LFWPairs(_LFW): class LFWPairs(_LFW):
...@@ -204,13 +202,13 @@ class LFWPairs(_LFW): ...@@ -204,13 +202,13 @@ class LFWPairs(_LFW):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
): ):
super(LFWPairs, self).__init__(root, split, image_set, "pairs", transform, target_transform, download) super().__init__(root, split, image_set, "pairs", transform, target_transform, download)
self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir) self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)
def _get_pairs(self, images_dir): def _get_pairs(self, images_dir):
pair_names, data, targets = [], [], [] pair_names, data, targets = [], [], []
with open(os.path.join(self.root, self.labels_file), "r") as f: with open(os.path.join(self.root, self.labels_file)) as f:
lines = f.readlines() lines = f.readlines()
if self.split == "10fold": if self.split == "10fold":
n_folds, n_pairs = lines[0].split("\t") n_folds, n_pairs = lines[0].split("\t")
......
...@@ -18,7 +18,7 @@ class LSUNClass(VisionDataset): ...@@ -18,7 +18,7 @@ class LSUNClass(VisionDataset):
) -> None: ) -> None:
import lmdb import lmdb
super(LSUNClass, self).__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False) self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
with self.env.begin(write=False) as txn: with self.env.begin(write=False) as txn:
...@@ -77,7 +77,7 @@ class LSUN(VisionDataset): ...@@ -77,7 +77,7 @@ class LSUN(VisionDataset):
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
) -> None: ) -> None:
super(LSUN, self).__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self.classes = self._verify_classes(classes) self.classes = self._verify_classes(classes)
# for each class, create an LSUNClassDataset # for each class, create an LSUNClassDataset
...@@ -117,11 +117,11 @@ class LSUN(VisionDataset): ...@@ -117,11 +117,11 @@ class LSUN(VisionDataset):
classes = [c + "_" + classes for c in categories] classes = [c + "_" + classes for c in categories]
except ValueError: except ValueError:
if not isinstance(classes, Iterable): if not isinstance(classes, Iterable):
msg = "Expected type str or Iterable for argument classes, " "but got type {}." msg = "Expected type str or Iterable for argument classes, but got type {}."
raise ValueError(msg.format(type(classes))) raise ValueError(msg.format(type(classes)))
classes = list(classes) classes = list(classes)
msg_fmtstr_type = "Expected type str for elements in argument classes, " "but got type {}." msg_fmtstr_type = "Expected type str for elements in argument classes, but got type {}."
for c in classes: for c in classes:
verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c))) verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c)))
c_short = c.split("_") c_short = c.split("_")
......
...@@ -88,7 +88,7 @@ class MNIST(VisionDataset): ...@@ -88,7 +88,7 @@ class MNIST(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(MNIST, self).__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self.train = train # training set or test set self.train = train # training set or test set
if self._check_legacy_exist(): if self._check_legacy_exist():
...@@ -99,7 +99,7 @@ class MNIST(VisionDataset): ...@@ -99,7 +99,7 @@ class MNIST(VisionDataset):
self.download() self.download()
if not self._check_exists(): if not self._check_exists():
raise RuntimeError("Dataset not found." + " You can use download=True to download it") raise RuntimeError("Dataset not found. You can use download=True to download it")
self.data, self.targets = self._load_data() self.data, self.targets = self._load_data()
...@@ -181,21 +181,22 @@ class MNIST(VisionDataset): ...@@ -181,21 +181,22 @@ class MNIST(VisionDataset):
# download files # download files
for filename, md5 in self.resources: for filename, md5 in self.resources:
for mirror in self.mirrors: for mirror in self.mirrors:
url = "{}{}".format(mirror, filename) url = f"{mirror}{filename}"
try: try:
print("Downloading {}".format(url)) print(f"Downloading {url}")
download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
except URLError as error: except URLError as error:
print("Failed to download (trying next):\n{}".format(error)) print(f"Failed to download (trying next):\n{error}")
continue continue
finally: finally:
print() print()
break break
else: else:
raise RuntimeError("Error downloading {}".format(filename)) raise RuntimeError(f"Error downloading {filename}")
def extra_repr(self) -> str: def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test") split = "Train" if self.train is True else "Test"
return f"Split: {split}"
class FashionMNIST(MNIST): class FashionMNIST(MNIST):
...@@ -293,16 +294,16 @@ class EMNIST(MNIST): ...@@ -293,16 +294,16 @@ class EMNIST(MNIST):
self.split = verify_str_arg(split, "split", self.splits) self.split = verify_str_arg(split, "split", self.splits)
self.training_file = self._training_file(split) self.training_file = self._training_file(split)
self.test_file = self._test_file(split) self.test_file = self._test_file(split)
super(EMNIST, self).__init__(root, **kwargs) super().__init__(root, **kwargs)
self.classes = self.classes_split_dict[self.split] self.classes = self.classes_split_dict[self.split]
@staticmethod @staticmethod
def _training_file(split) -> str: def _training_file(split) -> str:
return "training_{}.pt".format(split) return f"training_{split}.pt"
@staticmethod @staticmethod
def _test_file(split) -> str: def _test_file(split) -> str:
return "test_{}.pt".format(split) return f"test_{split}.pt"
@property @property
def _file_prefix(self) -> str: def _file_prefix(self) -> str:
...@@ -424,7 +425,7 @@ class QMNIST(MNIST): ...@@ -424,7 +425,7 @@ class QMNIST(MNIST):
self.data_file = what + ".pt" self.data_file = what + ".pt"
self.training_file = self.data_file self.training_file = self.data_file
self.test_file = self.data_file self.test_file = self.data_file
super(QMNIST, self).__init__(root, train, **kwargs) super().__init__(root, train, **kwargs)
@property @property
def images_file(self) -> str: def images_file(self) -> str:
...@@ -482,7 +483,7 @@ class QMNIST(MNIST): ...@@ -482,7 +483,7 @@ class QMNIST(MNIST):
return img, target return img, target
def extra_repr(self) -> str: def extra_repr(self) -> str:
return "Split: {}".format(self.what) return f"Split: {self.what}"
def get_int(b: bytes) -> int: def get_int(b: bytes) -> int:
......
...@@ -39,19 +39,19 @@ class Omniglot(VisionDataset): ...@@ -39,19 +39,19 @@ class Omniglot(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(Omniglot, self).__init__(join(root, self.folder), transform=transform, target_transform=target_transform) super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
self.background = background self.background = background
if download: if download:
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
self.target_folder = join(self.root, self._get_target_folder()) self.target_folder = join(self.root, self._get_target_folder())
self._alphabets = list_dir(self.target_folder) self._alphabets = list_dir(self.target_folder)
self._characters: List[str] = sum( self._characters: List[str] = sum(
[[join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets], [] ([join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets), []
) )
self._character_images = [ self._character_images = [
[(image, idx) for image in list_files(join(self.target_folder, character), ".png")] [(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
......
...@@ -89,11 +89,11 @@ class PhotoTour(VisionDataset): ...@@ -89,11 +89,11 @@ class PhotoTour(VisionDataset):
def __init__( def __init__(
self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False
) -> None: ) -> None:
super(PhotoTour, self).__init__(root, transform=transform) super().__init__(root, transform=transform)
self.name = name self.name = name
self.data_dir = os.path.join(self.root, name) self.data_dir = os.path.join(self.root, name)
self.data_down = os.path.join(self.root, "{}.zip".format(name)) self.data_down = os.path.join(self.root, f"{name}.zip")
self.data_file = os.path.join(self.root, "{}.pt".format(name)) self.data_file = os.path.join(self.root, f"{name}.pt")
self.train = train self.train = train
self.mean = self.means[name] self.mean = self.means[name]
...@@ -139,7 +139,7 @@ class PhotoTour(VisionDataset): ...@@ -139,7 +139,7 @@ class PhotoTour(VisionDataset):
def download(self) -> None: def download(self) -> None:
if self._check_datafile_exists(): if self._check_datafile_exists():
print("# Found cached data {}".format(self.data_file)) print(f"# Found cached data {self.data_file}")
return return
if not self._check_downloaded(): if not self._check_downloaded():
...@@ -151,7 +151,7 @@ class PhotoTour(VisionDataset): ...@@ -151,7 +151,7 @@ class PhotoTour(VisionDataset):
download_url(url, self.root, filename, md5) download_url(url, self.root, filename, md5)
print("# Extracting data {}\n".format(self.data_down)) print(f"# Extracting data {self.data_down}\n")
import zipfile import zipfile
...@@ -162,7 +162,7 @@ class PhotoTour(VisionDataset): ...@@ -162,7 +162,7 @@ class PhotoTour(VisionDataset):
def cache(self) -> None: def cache(self) -> None:
# process and save as torch files # process and save as torch files
print("# Caching data {}".format(self.data_file)) print(f"# Caching data {self.data_file}")
dataset = ( dataset = (
read_image_file(self.data_dir, self.image_ext, self.lens[self.name]), read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
...@@ -174,7 +174,8 @@ class PhotoTour(VisionDataset): ...@@ -174,7 +174,8 @@ class PhotoTour(VisionDataset):
torch.save(dataset, f) torch.save(dataset, f)
def extra_repr(self) -> str: def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test") split = "Train" if self.train is True else "Test"
return f"Split: {split}"
def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor: def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
...@@ -209,7 +210,7 @@ def read_info_file(data_dir: str, info_file: str) -> torch.Tensor: ...@@ -209,7 +210,7 @@ def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
"""Return a Tensor containing the list of labels """Return a Tensor containing the list of labels
Read the file and keep only the ID of the 3D point. Read the file and keep only the ID of the 3D point.
""" """
with open(os.path.join(data_dir, info_file), "r") as f: with open(os.path.join(data_dir, info_file)) as f:
labels = [int(line.split()[0]) for line in f] labels = [int(line.split()[0]) for line in f]
return torch.LongTensor(labels) return torch.LongTensor(labels)
...@@ -220,7 +221,7 @@ def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor: ...@@ -220,7 +221,7 @@ def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
Matches are represented with a 1, non matches with a 0. Matches are represented with a 1, non matches with a 0.
""" """
matches = [] matches = []
with open(os.path.join(data_dir, matches_file), "r") as f: with open(os.path.join(data_dir, matches_file)) as f:
for line in f: for line in f:
line_split = line.split() line_split = line.split()
matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])]) matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
......
...@@ -117,7 +117,7 @@ class Places365(VisionDataset): ...@@ -117,7 +117,7 @@ class Places365(VisionDataset):
if not self._check_integrity(file, md5, download): if not self._check_integrity(file, md5, download):
self.download_devkit() self.download_devkit()
with open(file, "r") as fh: with open(file) as fh:
class_to_idx = dict(process(line) for line in fh) class_to_idx = dict(process(line) for line in fh)
return sorted(class_to_idx.keys()), class_to_idx return sorted(class_to_idx.keys()), class_to_idx
...@@ -132,7 +132,7 @@ class Places365(VisionDataset): ...@@ -132,7 +132,7 @@ class Places365(VisionDataset):
if not self._check_integrity(file, md5, download): if not self._check_integrity(file, md5, download):
self.download_devkit() self.download_devkit()
with open(file, "r") as fh: with open(file) as fh:
images = [process(line) for line in fh] images = [process(line) for line in fh]
_, targets = zip(*images) _, targets = zip(*images)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment