Unverified Commit d367a01a authored by Jirka Borovec's avatar Jirka Borovec Committed by GitHub
Browse files

Use f-strings almost everywhere, and other cleanups by applying pyupgrade (#4585)


Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 50dfe207
...@@ -54,7 +54,7 @@ class DistributedSampler(Sampler): ...@@ -54,7 +54,7 @@ class DistributedSampler(Sampler):
rank = dist.get_rank() rank = dist.get_rank()
assert ( assert (
len(dataset) % group_size == 0 len(dataset) % group_size == 0
), "dataset length must be a multiplier of group size" "dataset length: %d, group size: %d" % ( ), "dataset length must be a multiplier of group size dataset length: %d, group size: %d" % (
len(dataset), len(dataset),
group_size, group_size,
) )
...@@ -117,7 +117,7 @@ class UniformClipSampler(Sampler): ...@@ -117,7 +117,7 @@ class UniformClipSampler(Sampler):
def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None: def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None:
if not isinstance(video_clips, VideoClips): if not isinstance(video_clips, VideoClips):
raise TypeError("Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips))) raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
self.video_clips = video_clips self.video_clips = video_clips
self.num_clips_per_video = num_clips_per_video self.num_clips_per_video = num_clips_per_video
...@@ -151,7 +151,7 @@ class RandomClipSampler(Sampler): ...@@ -151,7 +151,7 @@ class RandomClipSampler(Sampler):
def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None: def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None:
if not isinstance(video_clips, VideoClips): if not isinstance(video_clips, VideoClips):
raise TypeError("Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips))) raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
self.video_clips = video_clips self.video_clips = video_clips
self.max_clips_per_video = max_clips_per_video self.max_clips_per_video = max_clips_per_video
......
...@@ -63,9 +63,9 @@ class SBDataset(VisionDataset): ...@@ -63,9 +63,9 @@ class SBDataset(VisionDataset):
self._loadmat = loadmat self._loadmat = loadmat
except ImportError: except ImportError:
raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: " "pip install scipy") raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
super(SBDataset, self).__init__(root, transforms) super().__init__(root, transforms)
self.image_set = verify_str_arg(image_set, "image_set", ("train", "val", "train_noval")) self.image_set = verify_str_arg(image_set, "image_set", ("train", "val", "train_noval"))
self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries")) self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
self.num_classes = 20 self.num_classes = 20
...@@ -83,11 +83,11 @@ class SBDataset(VisionDataset): ...@@ -83,11 +83,11 @@ class SBDataset(VisionDataset):
download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5) download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5)
if not os.path.isdir(sbd_root): if not os.path.isdir(sbd_root):
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
split_f = os.path.join(sbd_root, image_set.rstrip("\n") + ".txt") split_f = os.path.join(sbd_root, image_set.rstrip("\n") + ".txt")
with open(os.path.join(split_f), "r") as fh: with open(os.path.join(split_f)) as fh:
file_names = [x.strip() for x in fh.readlines()] file_names = [x.strip() for x in fh.readlines()]
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
......
...@@ -33,13 +33,13 @@ class SBU(VisionDataset): ...@@ -33,13 +33,13 @@ class SBU(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = True, download: bool = True,
) -> None: ) -> None:
super(SBU, self).__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
if download: if download:
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
# Read the caption for each photo # Read the caption for each photo
self.photos = [] self.photos = []
......
...@@ -35,13 +35,13 @@ class SEMEION(VisionDataset): ...@@ -35,13 +35,13 @@ class SEMEION(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = True, download: bool = True,
) -> None: ) -> None:
super(SEMEION, self).__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
if download: if download:
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
fp = os.path.join(self.root, self.filename) fp = os.path.join(self.root, self.filename)
data = np.loadtxt(fp) data = np.loadtxt(fp)
......
...@@ -53,14 +53,14 @@ class STL10(VisionDataset): ...@@ -53,14 +53,14 @@ class STL10(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(STL10, self).__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self.split = verify_str_arg(split, "split", self.splits) self.split = verify_str_arg(split, "split", self.splits)
self.folds = self._verify_folds(folds) self.folds = self._verify_folds(folds)
if download: if download:
self.download() self.download()
elif not self._check_integrity(): elif not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted. " "You can use download=True to download it") raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
# now load the picked numpy arrays # now load the picked numpy arrays
self.labels: Optional[np.ndarray] self.labels: Optional[np.ndarray]
...@@ -92,7 +92,7 @@ class STL10(VisionDataset): ...@@ -92,7 +92,7 @@ class STL10(VisionDataset):
elif isinstance(folds, int): elif isinstance(folds, int):
if folds in range(10): if folds in range(10):
return folds return folds
msg = "Value for argument folds should be in the range [0, 10), " "but got {}." msg = "Value for argument folds should be in the range [0, 10), but got {}."
raise ValueError(msg.format(folds)) raise ValueError(msg.format(folds))
else: else:
msg = "Expected type None or int for argument folds, but got type {}." msg = "Expected type None or int for argument folds, but got type {}."
...@@ -167,7 +167,7 @@ class STL10(VisionDataset): ...@@ -167,7 +167,7 @@ class STL10(VisionDataset):
if folds is None: if folds is None:
return return
path_to_folds = os.path.join(self.root, self.base_folder, self.folds_list_file) path_to_folds = os.path.join(self.root, self.base_folder, self.folds_list_file)
with open(path_to_folds, "r") as f: with open(path_to_folds) as f:
str_idx = f.read().splitlines()[folds] str_idx = f.read().splitlines()[folds]
list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ") list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ")
self.data = self.data[list_idx, :, :, :] self.data = self.data[list_idx, :, :, :]
......
...@@ -60,7 +60,7 @@ class SVHN(VisionDataset): ...@@ -60,7 +60,7 @@ class SVHN(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(SVHN, self).__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self.split = verify_str_arg(split, "split", tuple(self.split_list.keys())) self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
self.url = self.split_list[split][0] self.url = self.split_list[split][0]
self.filename = self.split_list[split][1] self.filename = self.split_list[split][1]
...@@ -70,7 +70,7 @@ class SVHN(VisionDataset): ...@@ -70,7 +70,7 @@ class SVHN(VisionDataset):
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
# import here rather than at top of file because this is # import here rather than at top of file because this is
# an optional dependency for torchvision # an optional dependency for torchvision
......
...@@ -65,9 +65,9 @@ class UCF101(VisionDataset): ...@@ -65,9 +65,9 @@ class UCF101(VisionDataset):
_video_min_dimension: int = 0, _video_min_dimension: int = 0,
_audio_samples: int = 0, _audio_samples: int = 0,
) -> None: ) -> None:
super(UCF101, self).__init__(root) super().__init__(root)
if not 1 <= fold <= 3: if not 1 <= fold <= 3:
raise ValueError("fold should be between 1 and 3, got {}".format(fold)) raise ValueError(f"fold should be between 1 and 3, got {fold}")
extensions = ("avi",) extensions = ("avi",)
self.fold = fold self.fold = fold
...@@ -102,10 +102,10 @@ class UCF101(VisionDataset): ...@@ -102,10 +102,10 @@ class UCF101(VisionDataset):
def _select_fold(self, video_list: List[str], annotation_path: str, fold: int, train: bool) -> List[int]: def _select_fold(self, video_list: List[str], annotation_path: str, fold: int, train: bool) -> List[int]:
name = "train" if train else "test" name = "train" if train else "test"
name = "{}list{:02d}.txt".format(name, fold) name = f"{name}list{fold:02d}.txt"
f = os.path.join(annotation_path, name) f = os.path.join(annotation_path, name)
selected_files = set() selected_files = set()
with open(f, "r") as fid: with open(f) as fid:
data = fid.readlines() data = fid.readlines()
data = [x.strip().split(" ")[0] for x in data] data = [x.strip().split(" ")[0] for x in data]
data = [os.path.join(self.root, x) for x in data] data = [os.path.join(self.root, x) for x in data]
......
...@@ -49,7 +49,7 @@ class USPS(VisionDataset): ...@@ -49,7 +49,7 @@ class USPS(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(USPS, self).__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
split = "train" if train else "test" split = "train" if train else "test"
url, filename, checksum = self.split_list[split] url, filename, checksum = self.split_list[split]
full_path = os.path.join(self.root, filename) full_path = os.path.join(self.root, filename)
......
...@@ -138,10 +138,10 @@ def download_url( ...@@ -138,10 +138,10 @@ def download_url(
try: try:
print("Downloading " + url + " to " + fpath) print("Downloading " + url + " to " + fpath)
_urlretrieve(url, fpath) _urlretrieve(url, fpath)
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] except (urllib.error.URLError, OSError) as e: # type: ignore[attr-defined]
if url[:5] == "https": if url[:5] == "https":
url = url.replace("https:", "http:") url = url.replace("https:", "http:")
print("Failed download. Trying https -> http instead." " Downloading " + url + " to " + fpath) print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath)
_urlretrieve(url, fpath) _urlretrieve(url, fpath)
else: else:
raise e raise e
...@@ -428,7 +428,7 @@ def download_and_extract_archive( ...@@ -428,7 +428,7 @@ def download_and_extract_archive(
download_url(url, download_root, filename, md5) download_url(url, download_root, filename, md5)
archive = os.path.join(download_root, filename) archive = os.path.join(download_root, filename)
print("Extracting {} to {}".format(archive, extract_root)) print(f"Extracting {archive} to {extract_root}")
extract_archive(archive, extract_root, remove_finished) extract_archive(archive, extract_root, remove_finished)
...@@ -460,7 +460,7 @@ def verify_str_arg( ...@@ -460,7 +460,7 @@ def verify_str_arg(
if custom_msg is not None: if custom_msg is not None:
msg = custom_msg msg = custom_msg
else: else:
msg = "Unknown value '{value}' for argument {arg}. " "Valid values are {{{valid_values}}}." msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}."
msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values)) msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
raise ValueError(msg) raise ValueError(msg)
......
...@@ -46,7 +46,7 @@ def unfold(tensor, size, step, dilation=1): ...@@ -46,7 +46,7 @@ def unfold(tensor, size, step, dilation=1):
return torch.as_strided(tensor, new_size, new_stride) return torch.as_strided(tensor, new_size, new_stride)
class _VideoTimestampsDataset(object): class _VideoTimestampsDataset:
""" """
Dataset used to parallelize the reading of the timestamps Dataset used to parallelize the reading of the timestamps
of a list of videos, given their paths in the filesystem. of a list of videos, given their paths in the filesystem.
...@@ -72,7 +72,7 @@ def _collate_fn(x): ...@@ -72,7 +72,7 @@ def _collate_fn(x):
return x return x
class VideoClips(object): class VideoClips:
""" """
Given a list of video files, computes all consecutive subvideos of size Given a list of video files, computes all consecutive subvideos of size
`clip_length_in_frames`, where the distance between each subvideo in the `clip_length_in_frames`, where the distance between each subvideo in the
...@@ -293,7 +293,7 @@ class VideoClips(object): ...@@ -293,7 +293,7 @@ class VideoClips(object):
video_idx (int): index of the video in `video_paths` video_idx (int): index of the video in `video_paths`
""" """
if idx >= self.num_clips(): if idx >= self.num_clips():
raise IndexError("Index {} out of range " "({} number of clips)".format(idx, self.num_clips())) raise IndexError(f"Index {idx} out of range ({self.num_clips()} number of clips)")
video_idx, clip_idx = self.get_clip_location(idx) video_idx, clip_idx = self.get_clip_location(idx)
video_path = self.video_paths[video_idx] video_path = self.video_paths[video_idx]
clip_pts = self.clips[video_idx][clip_idx] clip_pts = self.clips[video_idx][clip_idx]
...@@ -359,7 +359,7 @@ class VideoClips(object): ...@@ -359,7 +359,7 @@ class VideoClips(object):
resampling_idx = resampling_idx - resampling_idx[0] resampling_idx = resampling_idx - resampling_idx[0]
video = video[resampling_idx] video = video[resampling_idx]
info["video_fps"] = self.frame_rate info["video_fps"] = self.frame_rate
assert len(video) == self.num_frames, "{} x {}".format(video.shape, self.num_frames) assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
return video, audio, info, video_idx return video, audio, info, video_idx
def __getstate__(self): def __getstate__(self):
......
...@@ -43,7 +43,7 @@ class VisionDataset(data.Dataset): ...@@ -43,7 +43,7 @@ class VisionDataset(data.Dataset):
has_transforms = transforms is not None has_transforms = transforms is not None
has_separate_transform = transform is not None or target_transform is not None has_separate_transform = transform is not None or target_transform is not None
if has_transforms and has_separate_transform: if has_transforms and has_separate_transform:
raise ValueError("Only transforms or transform/target_transform can " "be passed as argument") raise ValueError("Only transforms or transform/target_transform can be passed as argument")
# for backwards-compatibility # for backwards-compatibility
self.transform = transform self.transform = transform
...@@ -68,9 +68,9 @@ class VisionDataset(data.Dataset): ...@@ -68,9 +68,9 @@ class VisionDataset(data.Dataset):
def __repr__(self) -> str: def __repr__(self) -> str:
head = "Dataset " + self.__class__.__name__ head = "Dataset " + self.__class__.__name__
body = ["Number of datapoints: {}".format(self.__len__())] body = [f"Number of datapoints: {self.__len__()}"]
if self.root is not None: if self.root is not None:
body.append("Root location: {}".format(self.root)) body.append(f"Root location: {self.root}")
body += self.extra_repr().splitlines() body += self.extra_repr().splitlines()
if hasattr(self, "transforms") and self.transforms is not None: if hasattr(self, "transforms") and self.transforms is not None:
body += [repr(self.transforms)] body += [repr(self.transforms)]
...@@ -79,13 +79,13 @@ class VisionDataset(data.Dataset): ...@@ -79,13 +79,13 @@ class VisionDataset(data.Dataset):
def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
lines = transform.__repr__().splitlines() lines = transform.__repr__().splitlines()
return ["{}{}".format(head, lines[0])] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
def extra_repr(self) -> str: def extra_repr(self) -> str:
return "" return ""
class StandardTransform(object): class StandardTransform:
def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
...@@ -99,7 +99,7 @@ class StandardTransform(object): ...@@ -99,7 +99,7 @@ class StandardTransform(object):
def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
lines = transform.__repr__().splitlines() lines = transform.__repr__().splitlines()
return ["{}{}".format(head, lines[0])] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
def __repr__(self) -> str: def __repr__(self) -> str:
body = [self.__class__.__name__] body = [self.__class__.__name__]
......
...@@ -114,7 +114,7 @@ class _VOCBase(VisionDataset): ...@@ -114,7 +114,7 @@ class _VOCBase(VisionDataset):
splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR) splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt") split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
with open(os.path.join(split_f), "r") as f: with open(os.path.join(split_f)) as f:
file_names = [x.strip() for x in f.readlines()] file_names = [x.strip() for x in f.readlines()]
image_dir = os.path.join(voc_root, "JPEGImages") image_dir = os.path.join(voc_root, "JPEGImages")
......
...@@ -62,7 +62,7 @@ class WIDERFace(VisionDataset): ...@@ -62,7 +62,7 @@ class WIDERFace(VisionDataset):
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
) -> None: ) -> None:
super(WIDERFace, self).__init__( super().__init__(
root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform
) )
# check arguments # check arguments
...@@ -72,9 +72,7 @@ class WIDERFace(VisionDataset): ...@@ -72,9 +72,7 @@ class WIDERFace(VisionDataset):
self.download() self.download()
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError( raise RuntimeError("Dataset not found or corrupted. You can use download=True to download and prepare it")
"Dataset not found or corrupted. " + "You can use download=True to download and prepare it"
)
self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = [] self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
if self.split in ("train", "val"): if self.split in ("train", "val"):
...@@ -115,7 +113,7 @@ class WIDERFace(VisionDataset): ...@@ -115,7 +113,7 @@ class WIDERFace(VisionDataset):
filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt" filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
filepath = os.path.join(self.root, "wider_face_split", filename) filepath = os.path.join(self.root, "wider_face_split", filename)
with open(filepath, "r") as f: with open(filepath) as f:
lines = f.readlines() lines = f.readlines()
file_name_line, num_boxes_line, box_annotation_line = True, False, False file_name_line, num_boxes_line, box_annotation_line = True, False, False
num_boxes, box_counter = 0, 0 num_boxes, box_counter = 0, 0
...@@ -157,12 +155,12 @@ class WIDERFace(VisionDataset): ...@@ -157,12 +155,12 @@ class WIDERFace(VisionDataset):
box_counter = 0 box_counter = 0
labels.clear() labels.clear()
else: else:
raise RuntimeError("Error parsing annotation file {}".format(filepath)) raise RuntimeError(f"Error parsing annotation file {filepath}")
def parse_test_annotations_file(self) -> None: def parse_test_annotations_file(self) -> None:
filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt") filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
filepath = abspath(expanduser(filepath)) filepath = abspath(expanduser(filepath))
with open(filepath, "r") as f: with open(filepath) as f:
lines = f.readlines() lines = f.readlines()
for line in lines: for line in lines:
line = line.rstrip() line = line.rstrip()
......
...@@ -60,10 +60,9 @@ def _check_cuda_version(): ...@@ -60,10 +60,9 @@ def _check_cuda_version():
if t_major != tv_major or t_minor != tv_minor: if t_major != tv_major or t_minor != tv_minor:
raise RuntimeError( raise RuntimeError(
"Detected that PyTorch and torchvision were compiled with different CUDA versions. " "Detected that PyTorch and torchvision were compiled with different CUDA versions. "
"PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. " f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
"Please reinstall the torchvision that matches your PyTorch install.".format( f"CUDA Version={tv_major}.{tv_minor}. "
t_major, t_minor, tv_major, tv_minor "Please reinstall the torchvision that matches your PyTorch install."
)
) )
return _version return _version
......
...@@ -20,7 +20,7 @@ default_timebase = Fraction(0, 1) ...@@ -20,7 +20,7 @@ default_timebase = Fraction(0, 1)
# simple class for torch scripting # simple class for torch scripting
# the complex Fraction class from fractions module is not scriptable # the complex Fraction class from fractions module is not scriptable
class Timebase(object): class Timebase:
__annotations__ = {"numerator": int, "denominator": int} __annotations__ = {"numerator": int, "denominator": int}
__slots__ = ["numerator", "denominator"] __slots__ = ["numerator", "denominator"]
...@@ -34,7 +34,7 @@ class Timebase(object): ...@@ -34,7 +34,7 @@ class Timebase(object):
self.denominator = denominator self.denominator = denominator
class VideoMetaData(object): class VideoMetaData:
__annotations__ = { __annotations__ = {
"has_video": bool, "has_video": bool,
"video_timebase": Timebase, "video_timebase": Timebase,
......
...@@ -161,7 +161,7 @@ def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor: ...@@ -161,7 +161,7 @@ def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
JPEG file. JPEG file.
""" """
if quality < 1 or quality > 100: if quality < 1 or quality > 100:
raise ValueError("Image quality should be a positive number " "between 1 and 100") raise ValueError("Image quality should be a positive number between 1 and 100")
output = torch.ops.image.encode_jpeg(input, quality) output = torch.ops.image.encode_jpeg(input, quality)
return output return output
......
...@@ -271,9 +271,7 @@ def read_video( ...@@ -271,9 +271,7 @@ def read_video(
end_pts = float("inf") end_pts = float("inf")
if end_pts < start_pts: if end_pts < start_pts:
raise ValueError( raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
"end_pts should be larger than start_pts, got " "start_pts={} and end_pts={}".format(start_pts, end_pts)
)
info = {} info = {}
video_frames = [] video_frames = []
......
...@@ -54,7 +54,7 @@ class IntermediateLayerGetter(nn.ModuleDict): ...@@ -54,7 +54,7 @@ class IntermediateLayerGetter(nn.ModuleDict):
if not return_layers: if not return_layers:
break break
super(IntermediateLayerGetter, self).__init__(layers) super().__init__(layers)
self.return_layers = orig_return_layers self.return_layers = orig_return_layers
def forward(self, x): def forward(self, x):
......
...@@ -17,7 +17,7 @@ model_urls = { ...@@ -17,7 +17,7 @@ model_urls = {
class AlexNet(nn.Module): class AlexNet(nn.Module):
def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None: def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
super(AlexNet, self).__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
self.features = nn.Sequential( self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
......
...@@ -26,7 +26,7 @@ class _DenseLayer(nn.Module): ...@@ -26,7 +26,7 @@ class _DenseLayer(nn.Module):
def __init__( def __init__(
self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
) -> None: ) -> None:
super(_DenseLayer, self).__init__() super().__init__()
self.norm1: nn.BatchNorm2d self.norm1: nn.BatchNorm2d
self.add_module("norm1", nn.BatchNorm2d(num_input_features)) self.add_module("norm1", nn.BatchNorm2d(num_input_features))
self.relu1: nn.ReLU self.relu1: nn.ReLU
...@@ -107,7 +107,7 @@ class _DenseBlock(nn.ModuleDict): ...@@ -107,7 +107,7 @@ class _DenseBlock(nn.ModuleDict):
drop_rate: float, drop_rate: float,
memory_efficient: bool = False, memory_efficient: bool = False,
) -> None: ) -> None:
super(_DenseBlock, self).__init__() super().__init__()
for i in range(num_layers): for i in range(num_layers):
layer = _DenseLayer( layer = _DenseLayer(
num_input_features + i * growth_rate, num_input_features + i * growth_rate,
...@@ -128,7 +128,7 @@ class _DenseBlock(nn.ModuleDict): ...@@ -128,7 +128,7 @@ class _DenseBlock(nn.ModuleDict):
class _Transition(nn.Sequential): class _Transition(nn.Sequential):
def __init__(self, num_input_features: int, num_output_features: int) -> None: def __init__(self, num_input_features: int, num_output_features: int) -> None:
super(_Transition, self).__init__() super().__init__()
self.add_module("norm", nn.BatchNorm2d(num_input_features)) self.add_module("norm", nn.BatchNorm2d(num_input_features))
self.add_module("relu", nn.ReLU(inplace=True)) self.add_module("relu", nn.ReLU(inplace=True))
self.add_module("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) self.add_module("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
...@@ -162,7 +162,7 @@ class DenseNet(nn.Module): ...@@ -162,7 +162,7 @@ class DenseNet(nn.Module):
memory_efficient: bool = False, memory_efficient: bool = False,
) -> None: ) -> None:
super(DenseNet, self).__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
# First convolution # First convolution
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment