"references/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "55a8300b24063cb5be37bcb6c3cc19dc0852b813"
Unverified Commit d367a01a authored by Jirka Borovec's avatar Jirka Borovec Committed by GitHub
Browse files

Use f-strings almost everywhere, and other cleanups by applying pyupgrade (#4585)


Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 50dfe207
...@@ -8,7 +8,7 @@ import torch.nn as nn ...@@ -8,7 +8,7 @@ import torch.nn as nn
class TripletMarginLoss(nn.Module): class TripletMarginLoss(nn.Module):
def __init__(self, margin=1.0, p=2.0, mining="batch_all"): def __init__(self, margin=1.0, p=2.0, mining="batch_all"):
super(TripletMarginLoss, self).__init__() super().__init__()
self.margin = margin self.margin = margin
self.p = p self.p = p
self.mining = mining self.mining = mining
......
...@@ -4,7 +4,7 @@ import torchvision.models as models ...@@ -4,7 +4,7 @@ import torchvision.models as models
class EmbeddingNet(nn.Module): class EmbeddingNet(nn.Module):
def __init__(self, backbone=None): def __init__(self, backbone=None):
super(EmbeddingNet, self).__init__() super().__init__()
if backbone is None: if backbone is None:
backbone = models.resnet50(num_classes=128) backbone = models.resnet50(num_classes=128)
......
...@@ -31,7 +31,7 @@ def train_epoch(model, optimizer, criterion, data_loader, device, epoch, print_f ...@@ -31,7 +31,7 @@ def train_epoch(model, optimizer, criterion, data_loader, device, epoch, print_f
i += 1 i += 1
avg_loss = running_loss / print_freq avg_loss = running_loss / print_freq
avg_trip = 100.0 * running_frac_pos_triplets / print_freq avg_trip = 100.0 * running_frac_pos_triplets / print_freq
print("[{:d}, {:d}] | loss: {:.4f} | % avg hard triplets: {:.2f}%".format(epoch, i, avg_loss, avg_trip)) print(f"[{epoch:d}, {i:d}] | loss: {avg_loss:.4f} | % avg hard triplets: {avg_trip:.2f}%")
running_loss = 0 running_loss = 0
running_frac_pos_triplets = 0 running_frac_pos_triplets = 0
...@@ -77,7 +77,7 @@ def evaluate(model, loader, device): ...@@ -77,7 +77,7 @@ def evaluate(model, loader, device):
threshold, accuracy = find_best_threshold(dists, targets, device) threshold, accuracy = find_best_threshold(dists, targets, device)
print("accuracy: {:.3f}%, threshold: {:.2f}".format(accuracy, threshold)) print(f"accuracy: {accuracy:.3f}%, threshold: {threshold:.2f}")
def save(model, epoch, save_dir, file_name): def save(model, epoch, save_dir, file_name):
......
...@@ -24,7 +24,7 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi ...@@ -24,7 +24,7 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
metric_logger.add_meter("clips/s", utils.SmoothedValue(window_size=10, fmt="{value:.3f}")) metric_logger.add_meter("clips/s", utils.SmoothedValue(window_size=10, fmt="{value:.3f}"))
header = "Epoch: [{}]".format(epoch) header = f"Epoch: [{epoch}]"
for video, target in metric_logger.log_every(data_loader, print_freq, header): for video, target in metric_logger.log_every(data_loader, print_freq, header):
start_time = time.time() start_time = time.time()
video, target = video.to(device), target.to(device) video, target = video.to(device), target.to(device)
...@@ -122,12 +122,12 @@ def main(args): ...@@ -122,12 +122,12 @@ def main(args):
transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112)) transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112))
if args.cache_dataset and os.path.exists(cache_path): if args.cache_dataset and os.path.exists(cache_path):
print("Loading dataset_train from {}".format(cache_path)) print(f"Loading dataset_train from {cache_path}")
dataset, _ = torch.load(cache_path) dataset, _ = torch.load(cache_path)
dataset.transform = transform_train dataset.transform = transform_train
else: else:
if args.distributed: if args.distributed:
print("It is recommended to pre-compute the dataset cache " "on a single-gpu first, as it will be faster") print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
dataset = torchvision.datasets.Kinetics400( dataset = torchvision.datasets.Kinetics400(
traindir, traindir,
frames_per_clip=args.clip_len, frames_per_clip=args.clip_len,
...@@ -140,7 +140,7 @@ def main(args): ...@@ -140,7 +140,7 @@ def main(args):
), ),
) )
if args.cache_dataset: if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path)) print(f"Saving dataset_train to {cache_path}")
utils.mkdir(os.path.dirname(cache_path)) utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path) utils.save_on_master((dataset, traindir), cache_path)
...@@ -152,12 +152,12 @@ def main(args): ...@@ -152,12 +152,12 @@ def main(args):
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
if args.cache_dataset and os.path.exists(cache_path): if args.cache_dataset and os.path.exists(cache_path):
print("Loading dataset_test from {}".format(cache_path)) print(f"Loading dataset_test from {cache_path}")
dataset_test, _ = torch.load(cache_path) dataset_test, _ = torch.load(cache_path)
dataset_test.transform = transform_test dataset_test.transform = transform_test
else: else:
if args.distributed: if args.distributed:
print("It is recommended to pre-compute the dataset cache " "on a single-gpu first, as it will be faster") print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
dataset_test = torchvision.datasets.Kinetics400( dataset_test = torchvision.datasets.Kinetics400(
valdir, valdir,
frames_per_clip=args.clip_len, frames_per_clip=args.clip_len,
...@@ -170,7 +170,7 @@ def main(args): ...@@ -170,7 +170,7 @@ def main(args):
), ),
) )
if args.cache_dataset: if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path)) print(f"Saving dataset_test to {cache_path}")
utils.mkdir(os.path.dirname(cache_path)) utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset_test, valdir), cache_path) utils.save_on_master((dataset_test, valdir), cache_path)
...@@ -232,8 +232,7 @@ def main(args): ...@@ -232,8 +232,7 @@ def main(args):
) )
else: else:
raise RuntimeError( raise RuntimeError(
"Invalid warmup lr method '{}'. Only linear and constant " f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
"are supported.".format(args.lr_warmup_method)
) )
lr_scheduler = torch.optim.lr_scheduler.SequentialLR( lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
...@@ -275,12 +274,12 @@ def main(args): ...@@ -275,12 +274,12 @@ def main(args):
"epoch": epoch, "epoch": epoch,
"args": args, "args": args,
} }
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Training time {}".format(total_time_str)) print(f"Training time {total_time_str}")
def parse_args(): def parse_args():
......
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
class SmoothedValue(object): class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a """Track a series of values and provide access to smoothed values over a
window or the global series average. window or the global series average.
""" """
...@@ -67,7 +67,7 @@ class SmoothedValue(object): ...@@ -67,7 +67,7 @@ class SmoothedValue(object):
) )
class MetricLogger(object): class MetricLogger:
def __init__(self, delimiter="\t"): def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue) self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter self.delimiter = delimiter
...@@ -84,12 +84,12 @@ class MetricLogger(object): ...@@ -84,12 +84,12 @@ class MetricLogger(object):
return self.meters[attr] return self.meters[attr]
if attr in self.__dict__: if attr in self.__dict__:
return self.__dict__[attr] return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self): def __str__(self):
loss_str = [] loss_str = []
for name, meter in self.meters.items(): for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter))) loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str) return self.delimiter.join(loss_str)
def synchronize_between_processes(self): def synchronize_between_processes(self):
...@@ -154,7 +154,7 @@ class MetricLogger(object): ...@@ -154,7 +154,7 @@ class MetricLogger(object):
end = time.time() end = time.time()
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {}".format(header, total_time_str)) print(f"{header} Total time: {total_time_str}")
def accuracy(output, target, topk=(1,)): def accuracy(output, target, topk=(1,)):
...@@ -246,7 +246,7 @@ def init_distributed_mode(args): ...@@ -246,7 +246,7 @@ def init_distributed_mode(args):
torch.cuda.set_device(args.gpu) torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl" args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
) )
......
...@@ -14,7 +14,7 @@ df.tail() ...@@ -14,7 +14,7 @@ df.tail()
# In[3]: # In[3]:
all_labels = set(lbl for labels in df["labels"] for lbl in labels) all_labels = {lbl for labels in df["labels"] for lbl in labels}
all_labels all_labels
......
...@@ -96,7 +96,7 @@ def run_query(query): ...@@ -96,7 +96,7 @@ def run_query(query):
if request.status_code == 200: if request.status_code == 200:
return request.json() return request.json()
else: else:
raise Exception("Query failed to run by returning code of {}. {}".format(request.status_code, query)) raise Exception(f"Query failed to run by returning code of {request.status_code}. {query}")
def gh_labels(pr_number): def gh_labels(pr_number):
...@@ -151,7 +151,7 @@ class CommitDataCache: ...@@ -151,7 +151,7 @@ class CommitDataCache:
return self.data[commit] return self.data[commit]
def read_from_disk(self): def read_from_disk(self):
with open(self.path, "r") as f: with open(self.path) as f:
data = json.load(f) data = json.load(f)
data = {commit: dict_to_features(dct) for commit, dct in data.items()} data = {commit: dict_to_features(dct) for commit, dct in data.items()}
return data return data
......
import distutils.command.clean import distutils.command.clean
import distutils.spawn import distutils.spawn
import glob import glob
import io
import os import os
import shutil import shutil
import subprocess import subprocess
...@@ -14,7 +13,7 @@ from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtensio ...@@ -14,7 +13,7 @@ from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtensio
def read(*names, **kwargs): def read(*names, **kwargs):
with io.open(os.path.join(os.path.dirname(__file__), *names), encoding=kwargs.get("encoding", "utf8")) as fp: with open(os.path.join(os.path.dirname(__file__), *names), encoding=kwargs.get("encoding", "utf8")) as fp:
return fp.read() return fp.read()
...@@ -28,7 +27,7 @@ def get_dist(pkgname): ...@@ -28,7 +27,7 @@ def get_dist(pkgname):
cwd = os.path.dirname(os.path.abspath(__file__)) cwd = os.path.dirname(os.path.abspath(__file__))
version_txt = os.path.join(cwd, "version.txt") version_txt = os.path.join(cwd, "version.txt")
with open(version_txt, "r") as f: with open(version_txt) as f:
version = f.readline().strip() version = f.readline().strip()
sha = "Unknown" sha = "Unknown"
package_name = "torchvision" package_name = "torchvision"
...@@ -47,8 +46,8 @@ elif sha != "Unknown": ...@@ -47,8 +46,8 @@ elif sha != "Unknown":
def write_version_file(): def write_version_file():
version_path = os.path.join(cwd, "torchvision", "version.py") version_path = os.path.join(cwd, "torchvision", "version.py")
with open(version_path, "w") as f: with open(version_path, "w") as f:
f.write("__version__ = '{}'\n".format(version)) f.write(f"__version__ = '{version}'\n")
f.write("git_version = {}\n".format(repr(sha))) f.write(f"git_version = {repr(sha)}\n")
f.write("from torchvision.extension import _check_cuda_version\n") f.write("from torchvision.extension import _check_cuda_version\n")
f.write("if _check_cuda_version() > 0:\n") f.write("if _check_cuda_version() > 0:\n")
f.write(" cuda = _check_cuda_version()\n") f.write(" cuda = _check_cuda_version()\n")
...@@ -78,7 +77,7 @@ def find_library(name, vision_include): ...@@ -78,7 +77,7 @@ def find_library(name, vision_include):
conda_installed = False conda_installed = False
lib_folder = None lib_folder = None
include_folder = None include_folder = None
library_header = "{0}.h".format(name) library_header = f"{name}.h"
# Lookup in TORCHVISION_INCLUDE or in the package file # Lookup in TORCHVISION_INCLUDE or in the package file
package_path = [os.path.join(this_dir, "torchvision")] package_path = [os.path.join(this_dir, "torchvision")]
...@@ -89,7 +88,7 @@ def find_library(name, vision_include): ...@@ -89,7 +88,7 @@ def find_library(name, vision_include):
break break
if not library_found: if not library_found:
print("Running build on conda-build: {0}".format(is_conda_build)) print(f"Running build on conda-build: {is_conda_build}")
if is_conda_build: if is_conda_build:
# Add conda headers/libraries # Add conda headers/libraries
if os.name == "nt": if os.name == "nt":
...@@ -103,7 +102,7 @@ def find_library(name, vision_include): ...@@ -103,7 +102,7 @@ def find_library(name, vision_include):
# Check if using Anaconda to produce wheels # Check if using Anaconda to produce wheels
conda = distutils.spawn.find_executable("conda") conda = distutils.spawn.find_executable("conda")
is_conda = conda is not None is_conda = conda is not None
print("Running build on conda: {0}".format(is_conda)) print(f"Running build on conda: {is_conda}")
if is_conda: if is_conda:
python_executable = sys.executable python_executable = sys.executable
py_folder = os.path.dirname(python_executable) py_folder = os.path.dirname(python_executable)
...@@ -119,8 +118,8 @@ def find_library(name, vision_include): ...@@ -119,8 +118,8 @@ def find_library(name, vision_include):
if not library_found: if not library_found:
if sys.platform == "linux": if sys.platform == "linux":
library_found = os.path.exists("/usr/include/{0}".format(library_header)) library_found = os.path.exists(f"/usr/include/{library_header}")
library_found = library_found or os.path.exists("/usr/local/include/{0}".format(library_header)) library_found = library_found or os.path.exists(f"/usr/local/include/{library_header}")
return library_found, conda_installed, include_folder, lib_folder return library_found, conda_installed, include_folder, lib_folder
...@@ -258,13 +257,13 @@ def get_extensions(): ...@@ -258,13 +257,13 @@ def get_extensions():
libpng = distutils.spawn.find_executable("libpng-config") libpng = distutils.spawn.find_executable("libpng-config")
pngfix = distutils.spawn.find_executable("pngfix") pngfix = distutils.spawn.find_executable("pngfix")
png_found = libpng is not None or pngfix is not None png_found = libpng is not None or pngfix is not None
print("PNG found: {0}".format(png_found)) print(f"PNG found: {png_found}")
if png_found: if png_found:
if libpng is not None: if libpng is not None:
# Linux / Mac # Linux / Mac
png_version = subprocess.run([libpng, "--version"], stdout=subprocess.PIPE) png_version = subprocess.run([libpng, "--version"], stdout=subprocess.PIPE)
png_version = png_version.stdout.strip().decode("utf-8") png_version = png_version.stdout.strip().decode("utf-8")
print("libpng version: {0}".format(png_version)) print(f"libpng version: {png_version}")
png_version = parse_version(png_version) png_version = parse_version(png_version)
if png_version >= parse_version("1.6.0"): if png_version >= parse_version("1.6.0"):
print("Building torchvision with PNG image support") print("Building torchvision with PNG image support")
...@@ -275,11 +274,11 @@ def get_extensions(): ...@@ -275,11 +274,11 @@ def get_extensions():
png_include = subprocess.run([libpng, "--I_opts"], stdout=subprocess.PIPE) png_include = subprocess.run([libpng, "--I_opts"], stdout=subprocess.PIPE)
png_include = png_include.stdout.strip().decode("utf-8") png_include = png_include.stdout.strip().decode("utf-8")
_, png_include = png_include.split("-I") _, png_include = png_include.split("-I")
print("libpng include path: {0}".format(png_include)) print(f"libpng include path: {png_include}")
image_include += [png_include] image_include += [png_include]
image_link_flags.append("png") image_link_flags.append("png")
else: else:
print("libpng installed version is less than 1.6.0, " "disabling PNG support") print("libpng installed version is less than 1.6.0, disabling PNG support")
png_found = False png_found = False
else: else:
# Windows # Windows
...@@ -292,7 +291,7 @@ def get_extensions(): ...@@ -292,7 +291,7 @@ def get_extensions():
# Locating libjpeg # Locating libjpeg
(jpeg_found, jpeg_conda, jpeg_include, jpeg_lib) = find_library("jpeglib", vision_include) (jpeg_found, jpeg_conda, jpeg_include, jpeg_lib) = find_library("jpeglib", vision_include)
print("JPEG found: {0}".format(jpeg_found)) print(f"JPEG found: {jpeg_found}")
image_macros += [("PNG_FOUND", str(int(png_found)))] image_macros += [("PNG_FOUND", str(int(png_found)))]
image_macros += [("JPEG_FOUND", str(int(jpeg_found)))] image_macros += [("JPEG_FOUND", str(int(jpeg_found)))]
if jpeg_found: if jpeg_found:
...@@ -310,7 +309,7 @@ def get_extensions(): ...@@ -310,7 +309,7 @@ def get_extensions():
and os.path.exists(os.path.join(CUDA_HOME, "include", "nvjpeg.h")) and os.path.exists(os.path.join(CUDA_HOME, "include", "nvjpeg.h"))
) )
print("NVJPEG found: {0}".format(nvjpeg_found)) print(f"NVJPEG found: {nvjpeg_found}")
image_macros += [("NVJPEG_FOUND", str(int(nvjpeg_found)))] image_macros += [("NVJPEG_FOUND", str(int(nvjpeg_found)))]
if nvjpeg_found: if nvjpeg_found:
print("Building torchvision with NVJPEG image support") print("Building torchvision with NVJPEG image support")
...@@ -352,7 +351,7 @@ def get_extensions(): ...@@ -352,7 +351,7 @@ def get_extensions():
print("Error fetching ffmpeg version, ignoring ffmpeg.") print("Error fetching ffmpeg version, ignoring ffmpeg.")
has_ffmpeg = False has_ffmpeg = False
print("FFmpeg found: {}".format(has_ffmpeg)) print(f"FFmpeg found: {has_ffmpeg}")
if has_ffmpeg: if has_ffmpeg:
ffmpeg_libraries = {"libavcodec", "libavformat", "libavutil", "libswresample", "libswscale"} ffmpeg_libraries = {"libavcodec", "libavformat", "libavutil", "libswresample", "libswscale"}
...@@ -386,8 +385,8 @@ def get_extensions(): ...@@ -386,8 +385,8 @@ def get_extensions():
has_ffmpeg = False has_ffmpeg = False
if has_ffmpeg: if has_ffmpeg:
print("ffmpeg include path: {}".format(ffmpeg_include_dir)) print(f"ffmpeg include path: {ffmpeg_include_dir}")
print("ffmpeg library_dir: {}".format(ffmpeg_library_dir)) print(f"ffmpeg library_dir: {ffmpeg_library_dir}")
# TorchVision base decoder + video reader # TorchVision base decoder + video reader
video_reader_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video_reader") video_reader_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video_reader")
...@@ -432,7 +431,7 @@ def get_extensions(): ...@@ -432,7 +431,7 @@ def get_extensions():
class clean(distutils.command.clean.clean): class clean(distutils.command.clean.clean):
def run(self): def run(self):
with open(".gitignore", "r") as f: with open(".gitignore") as f:
ignores = f.read() ignores = f.read()
for wildcard in filter(None, ignores.split("\n")): for wildcard in filter(None, ignores.split("\n")):
for filename in glob.glob(wildcard): for filename in glob.glob(wildcard):
...@@ -446,7 +445,7 @@ class clean(distutils.command.clean.clean): ...@@ -446,7 +445,7 @@ class clean(distutils.command.clean.clean):
if __name__ == "__main__": if __name__ == "__main__":
print("Building wheel {}-{}".format(package_name, version)) print(f"Building wheel {package_name}-{version}")
write_version_file() write_version_file()
...@@ -472,6 +471,7 @@ if __name__ == "__main__": ...@@ -472,6 +471,7 @@ if __name__ == "__main__":
"scipy": ["scipy"], "scipy": ["scipy"],
}, },
ext_modules=get_extensions(), ext_modules=get_extensions(),
python_requires=">=3.6",
cmdclass={ cmdclass={
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True), "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
"clean": clean, "clean": clean,
......
...@@ -37,7 +37,7 @@ def set_rng_seed(seed): ...@@ -37,7 +37,7 @@ def set_rng_seed(seed):
random.seed(seed) random.seed(seed)
class MapNestedTensorObjectImpl(object): class MapNestedTensorObjectImpl:
def __init__(self, tensor_map_fn): def __init__(self, tensor_map_fn):
self.tensor_map_fn = tensor_map_fn self.tensor_map_fn = tensor_map_fn
...@@ -152,7 +152,7 @@ def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None): ...@@ -152,7 +152,7 @@ def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
else: else:
f = fps[i] f = fps[i]
data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8) data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
name = os.path.join(tmpdir, "{}.mp4".format(i)) name = os.path.join(tmpdir, f"{i}.mp4")
names.append(name) names.append(name)
io.write_video(name, data, fps=f) io.write_video(name, data, fps=f)
...@@ -165,7 +165,7 @@ def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None): ...@@ -165,7 +165,7 @@ def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
np_pil_image = np_pil_image[:, :, None] np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))) pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
if msg is None: if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor) msg = f"tensor:\n{tensor} \ndid not equal PIL tensor:\n{pil_tensor}"
assert_equal(tensor.cpu(), pil_tensor, msg=msg) assert_equal(tensor.cpu(), pil_tensor, msg=msg)
......
...@@ -133,7 +133,7 @@ def test_all_configs(test): ...@@ -133,7 +133,7 @@ def test_all_configs(test):
def maybe_remove_duplicates(configs): def maybe_remove_duplicates(configs):
try: try:
return [dict(config_) for config_ in set(tuple(sorted(config.items())) for config in configs)] return [dict(config_) for config_ in {tuple(sorted(config.items())) for config in configs}]
except TypeError: except TypeError:
# A TypeError will be raised if a value of any config is not hashable, e.g. a list. In that case duplicate # A TypeError will be raised if a value of any config is not hashable, e.g. a list. In that case duplicate
# removal would be a lot more elaborate and we simply bail out. # removal would be a lot more elaborate and we simply bail out.
......
...@@ -26,7 +26,7 @@ if __name__ == "__main__": ...@@ -26,7 +26,7 @@ if __name__ == "__main__":
if args.accimage: if args.accimage:
torchvision.set_image_backend("accimage") torchvision.set_image_backend("accimage")
print("Using {}".format(torchvision.get_image_backend())) print(f"Using {torchvision.get_image_backend()}")
# Data loading code # Data loading code
transform = transforms.Compose( transform = transforms.Compose(
......
...@@ -144,7 +144,7 @@ class TestFxFeatureExtraction: ...@@ -144,7 +144,7 @@ class TestFxFeatureExtraction:
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
) )
out = model(self.inp) out = model(self.inp)
sum([o.mean() for o in out.values()]).backward() sum(o.mean() for o in out.values()).backward()
def test_feature_extraction_methods_equivalence(self): def test_feature_extraction_methods_equivalence(self):
model = models.resnet18(**self.model_defaults).eval() model = models.resnet18(**self.model_defaults).eval()
...@@ -176,7 +176,7 @@ class TestFxFeatureExtraction: ...@@ -176,7 +176,7 @@ class TestFxFeatureExtraction:
) )
model = torch.jit.script(model) model = torch.jit.script(model)
fgn_out = model(self.inp) fgn_out = model(self.inp)
sum([o.mean() for o in fgn_out.values()]).backward() sum(o.mean() for o in fgn_out.values()).backward()
def test_train_eval(self): def test_train_eval(self):
class TestModel(torch.nn.Module): class TestModel(torch.nn.Module):
......
...@@ -654,7 +654,7 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -654,7 +654,7 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
shutil.copytree(src, root / "Segmentation") shutil.copytree(src, root / "Segmentation")
num_images = max(itertools.chain(*idcs.values())) + 1 num_images = max(itertools.chain(*idcs.values())) + 1
num_images_per_image_set = dict([(image_set, len(idcs_)) for image_set, idcs_ in idcs.items()]) num_images_per_image_set = {image_set: len(idcs_) for image_set, idcs_ in idcs.items()}
return num_images, num_images_per_image_set return num_images, num_images_per_image_set
def _create_image_set_file(self, root, image_set, idcs): def _create_image_set_file(self, root, image_set, idcs):
...@@ -1174,7 +1174,7 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1174,7 +1174,7 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
self._create_split_file(root, split, idcs) self._create_split_file(root, split, idcs)
num_images = max(itertools.chain(*splits.values())) + 1 num_images = max(itertools.chain(*splits.values())) + 1
num_images_per_split = dict([(split, len(idcs)) for split, idcs in splits.items()]) num_images_per_split = {split: len(idcs) for split, idcs in splits.items()}
return num_images, num_images_per_split return num_images, num_images_per_split
def _create_split_file(self, root, name, idcs): def _create_split_file(self, root, name, idcs):
......
...@@ -129,7 +129,7 @@ class TestDatasetsUtils: ...@@ -129,7 +129,7 @@ class TestDatasetsUtils:
assert os.path.exists(file) assert os.path.exists(file)
with open(file, "r") as fh: with open(file) as fh:
assert fh.read() == content assert fh.read() == content
def test_decompress_no_compression(self): def test_decompress_no_compression(self):
...@@ -179,7 +179,7 @@ class TestDatasetsUtils: ...@@ -179,7 +179,7 @@ class TestDatasetsUtils:
assert os.path.exists(file) assert os.path.exists(file)
with open(file, "r") as fh: with open(file) as fh:
assert fh.read() == content assert fh.read() == content
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -205,7 +205,7 @@ class TestDatasetsUtils: ...@@ -205,7 +205,7 @@ class TestDatasetsUtils:
assert os.path.exists(file) assert os.path.exists(file)
with open(file, "r") as fh: with open(file) as fh:
assert fh.read() == content assert fh.read() == content
def test_verify_str_arg(self): def test_verify_str_arg(self):
......
...@@ -111,9 +111,9 @@ class TestRotate: ...@@ -111,9 +111,9 @@ class TestRotate:
if out_tensor.dtype != torch.uint8: if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8) out_tensor = out_tensor.to(torch.uint8)
assert out_tensor.shape == out_pil_tensor.shape, ( assert (
f"{(height, width, NEAREST, dt, angle, expand, center)}: " f"{out_tensor.shape} vs {out_pil_tensor.shape}" out_tensor.shape == out_pil_tensor.shape
) ), f"{(height, width, NEAREST, dt, angle, expand, center)}: {out_tensor.shape} vs {out_pil_tensor.shape}"
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
...@@ -177,11 +177,11 @@ class TestAffine: ...@@ -177,11 +177,11 @@ class TestAffine:
# 1) identity map # 1) identity map
out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST) out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}")
out_tensor = self.scripted_affine( out_tensor = self.scripted_affine(
tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
) )
assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}")
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("height, width", [(26, 26)]) @pytest.mark.parametrize("height, width", [(26, 26)])
...@@ -225,9 +225,7 @@ class TestAffine: ...@@ -225,9 +225,7 @@ class TestAffine:
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 6% of different pixels # Tolerance : less than 6% of different pixels
assert ratio_diff_pixels < 0.06, "{}\n{} vs \n{}".format( assert ratio_diff_pixels < 0.06
ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("height, width", [(32, 26)]) @pytest.mark.parametrize("height, width", [(32, 26)])
...@@ -258,9 +256,7 @@ class TestAffine: ...@@ -258,9 +256,7 @@ class TestAffine:
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels # Tolerance : less than 3% of different pixels
assert ratio_diff_pixels < 0.03, "{}: {}\n{} vs \n{}".format( assert ratio_diff_pixels < 0.03
angle, ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("height, width", [(26, 26), (32, 26)]) @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)])
...@@ -346,9 +342,7 @@ class TestAffine: ...@@ -346,9 +342,7 @@ class TestAffine:
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 5% (cpu), 6% (cuda) of different pixels # Tolerance : less than 5% (cpu), 6% (cuda) of different pixels
tol = 0.06 if device == "cuda" else 0.05 tol = 0.06 if device == "cuda" else 0.05
assert ratio_diff_pixels < tol, "{}: {}\n{} vs \n{}".format( assert ratio_diff_pixels < tol
(NEAREST, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dt", ALL_DTYPES) @pytest.mark.parametrize("dt", ALL_DTYPES)
...@@ -936,7 +930,7 @@ def test_pad(device, dt, pad, config): ...@@ -936,7 +930,7 @@ def test_pad(device, dt, pad, config):
if pad_tensor_8b.dtype != torch.uint8: if pad_tensor_8b.dtype != torch.uint8:
pad_tensor_8b = pad_tensor_8b.to(torch.uint8) pad_tensor_8b = pad_tensor_8b.to(torch.uint8)
_assert_equal_tensor_to_pil(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, config)) _assert_equal_tensor_to_pil(pad_tensor_8b, pad_pil_img, msg=f"{pad}, {config}")
if isinstance(pad, int): if isinstance(pad, int):
script_pad = [ script_pad = [
...@@ -945,7 +939,7 @@ def test_pad(device, dt, pad, config): ...@@ -945,7 +939,7 @@ def test_pad(device, dt, pad, config):
else: else:
script_pad = pad script_pad = pad
pad_tensor_script = script_fn(tensor, script_pad, **config) pad_tensor_script = script_fn(tensor, script_pad, **config)
assert_equal(pad_tensor, pad_tensor_script, msg="{}, {}".format(pad, config)) assert_equal(pad_tensor, pad_tensor_script, msg=f"{pad}, {config}")
_test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **config) _test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **config)
...@@ -958,7 +952,7 @@ def test_resized_crop(device, mode): ...@@ -958,7 +952,7 @@ def test_resized_crop(device, mode):
tensor, _ = _create_data(26, 36, device=device) tensor, _ = _create_data(26, 36, device=device)
out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode) out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) assert_equal(tensor, out_tensor, msg=f"{out_tensor[0, :5, :5]} vs {tensor[0, :5, :5]}")
# 2) resize by half and crop a TL corner # 2) resize by half and crop a TL corner
tensor, _ = _create_data(26, 36, device=device) tensor, _ = _create_data(26, 36, device=device)
...@@ -967,7 +961,7 @@ def test_resized_crop(device, mode): ...@@ -967,7 +961,7 @@ def test_resized_crop(device, mode):
assert_equal( assert_equal(
expected_out_tensor, expected_out_tensor,
out_tensor, out_tensor,
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]), msg=f"{expected_out_tensor[0, :10, :10]} vs {out_tensor[0, :10, :10]}",
) )
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device) batch_tensors = _create_data_batch(26, 36, num_samples=4, device=device)
...@@ -1126,7 +1120,7 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn): ...@@ -1126,7 +1120,7 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
_ksize = (ksize, ksize) if isinstance(ksize, int) else ksize _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
_sigma = sigma[0] if sigma is not None else None _sigma = sigma[0] if sigma is not None else None
shape = tensor.shape shape = tensor.shape
gt_key = "{}_{}_{}__{}_{}_{}".format(shape[-2], shape[-1], shape[-3], _ksize[0], _ksize[1], _sigma) gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}"
if gt_key not in true_cv2_results: if gt_key not in true_cv2_results:
return return
...@@ -1135,7 +1129,7 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn): ...@@ -1135,7 +1129,7 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
) )
out = fn(tensor, kernel_size=ksize, sigma=sigma) out = fn(tensor, kernel_size=ksize, sigma=sigma)
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg="{}, {}".format(ksize, sigma)) torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
......
...@@ -233,7 +233,7 @@ def test_write_png(img_path, tmpdir): ...@@ -233,7 +233,7 @@ def test_write_png(img_path, tmpdir):
img_pil = img_pil.permute(2, 0, 1) img_pil = img_pil.permute(2, 0, 1)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
torch_png = os.path.join(tmpdir, "{0}_torch.png".format(filename)) torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
write_png(img_pil, torch_png, compression_level=6) write_png(img_pil, torch_png, compression_level=6)
saved_image = torch.from_numpy(np.array(Image.open(torch_png))) saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
saved_image = saved_image.permute(2, 0, 1) saved_image = saved_image.permute(2, 0, 1)
...@@ -393,10 +393,10 @@ def test_encode_jpeg_errors(): ...@@ -393,10 +393,10 @@ def test_encode_jpeg_errors():
with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32)) encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))
with pytest.raises(ValueError, match="Image quality should be a positive number " "between 1 and 100"): with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1) encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)
with pytest.raises(ValueError, match="Image quality should be a positive number " "between 1 and 100"): with pytest.raises(ValueError, match="Image quality should be a positive number between 1 and 100"):
encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101) encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)
with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"): with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"):
...@@ -440,7 +440,7 @@ def test_encode_jpeg_reference(img_path): ...@@ -440,7 +440,7 @@ def test_encode_jpeg_reference(img_path):
dirname = os.path.dirname(img_path) dirname = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
write_folder = os.path.join(dirname, "jpeg_write") write_folder = os.path.join(dirname, "jpeg_write")
expected_file = os.path.join(write_folder, "{0}_pil.jpg".format(filename)) expected_file = os.path.join(write_folder, f"{filename}_pil.jpg")
img = decode_jpeg(read_file(img_path)) img = decode_jpeg(read_file(img_path))
with open(expected_file, "rb") as f: with open(expected_file, "rb") as f:
...@@ -464,8 +464,8 @@ def test_write_jpeg_reference(img_path, tmpdir): ...@@ -464,8 +464,8 @@ def test_write_jpeg_reference(img_path, tmpdir):
basedir = os.path.dirname(img_path) basedir = os.path.dirname(img_path)
filename, _ = os.path.splitext(os.path.basename(img_path)) filename, _ = os.path.splitext(os.path.basename(img_path))
torch_jpeg = os.path.join(tmpdir, "{0}_torch.jpg".format(filename)) torch_jpeg = os.path.join(tmpdir, f"{filename}_torch.jpg")
pil_jpeg = os.path.join(basedir, "jpeg_write", "{0}_pil.jpg".format(filename)) pil_jpeg = os.path.join(basedir, "jpeg_write", f"{filename}_pil.jpg")
write_jpeg(img, torch_jpeg, quality=75) write_jpeg(img, torch_jpeg, quality=75)
......
...@@ -54,12 +54,12 @@ def _assert_expected(output, name, prec): ...@@ -54,12 +54,12 @@ def _assert_expected(output, name, prec):
if ACCEPT: if ACCEPT:
filename = {os.path.basename(expected_file)} filename = {os.path.basename(expected_file)}
print("Accepting updated output for {}:\n\n{}".format(filename, output)) print(f"Accepting updated output for {filename}:\n\n{output}")
torch.save(output, expected_file) torch.save(output, expected_file)
MAX_PICKLE_SIZE = 50 * 1000 # 50 KB MAX_PICKLE_SIZE = 50 * 1000 # 50 KB
binary_size = os.path.getsize(expected_file) binary_size = os.path.getsize(expected_file)
if binary_size > MAX_PICKLE_SIZE: if binary_size > MAX_PICKLE_SIZE:
raise RuntimeError("The output for {}, is larger than 50kb".format(filename)) raise RuntimeError(f"The output for {filename}, is larger than 50kb")
else: else:
expected = torch.load(expected_file) expected = torch.load(expected_file)
rtol = atol = prec rtol = atol = prec
...@@ -99,12 +99,12 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False): ...@@ -99,12 +99,12 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
if not TEST_WITH_SLOW or skip: if not TEST_WITH_SLOW or skip:
# TorchScript is not enabled, skip these tests # TorchScript is not enabled, skip these tests
msg = ( msg = (
"The check_jit_scriptable test for {} was skipped. " f"The check_jit_scriptable test for {nn_module.__class__.__name__} was skipped. "
"This test checks if the module's results in TorchScript " "This test checks if the module's results in TorchScript "
"match eager and that it can be exported. To run these " "match eager and that it can be exported. To run these "
"tests make sure you set the environment variable " "tests make sure you set the environment variable "
"PYTORCH_TEST_WITH_SLOW=1 and that the test is not " "PYTORCH_TEST_WITH_SLOW=1 and that the test is not "
"manually skipped.".format(nn_module.__class__.__name__) "manually skipped."
) )
warnings.warn(msg, RuntimeWarning) warnings.warn(msg, RuntimeWarning)
return None return None
...@@ -300,11 +300,11 @@ def test_memory_efficient_densenet(model_fn): ...@@ -300,11 +300,11 @@ def test_memory_efficient_densenet(model_fn):
model1 = model_fn(num_classes=50, memory_efficient=True) model1 = model_fn(num_classes=50, memory_efficient=True)
params = model1.state_dict() params = model1.state_dict()
num_params = sum([x.numel() for x in model1.parameters()]) num_params = sum(x.numel() for x in model1.parameters())
model1.eval() model1.eval()
out1 = model1(x) out1 = model1(x)
out1.sum().backward() out1.sum().backward()
num_grad = sum([x.grad.numel() for x in model1.parameters() if x.grad is not None]) num_grad = sum(x.grad.numel() for x in model1.parameters() if x.grad is not None)
model2 = model_fn(num_classes=50, memory_efficient=False) model2 = model_fn(num_classes=50, memory_efficient=False)
model2.load_state_dict(params) model2.load_state_dict(params)
...@@ -451,8 +451,8 @@ def test_generalizedrcnn_transform_repr(): ...@@ -451,8 +451,8 @@ def test_generalizedrcnn_transform_repr():
# Check integrity of object __repr__ attribute # Check integrity of object __repr__ attribute
expected_string = "GeneralizedRCNNTransform(" expected_string = "GeneralizedRCNNTransform("
_indent = "\n " _indent = "\n "
expected_string += "{0}Normalize(mean={1}, std={2})".format(_indent, image_mean, image_std) expected_string += f"{_indent}Normalize(mean={image_mean}, std={image_std})"
expected_string += "{0}Resize(min_size=({1},), max_size={2}, ".format(_indent, min_size, max_size) expected_string += f"{_indent}Resize(min_size=({min_size},), max_size={max_size}, "
expected_string += "mode='bilinear')\n)" expected_string += "mode='bilinear')\n)"
assert t.__repr__() == expected_string assert t.__repr__() == expected_string
...@@ -541,10 +541,10 @@ def test_segmentation_model(model_fn, dev): ...@@ -541,10 +541,10 @@ def test_segmentation_model(model_fn, dev):
if not full_validation: if not full_validation:
msg = ( msg = (
"The output of {} could only be partially validated. " f"The output of {test_segmentation_model.__name__} could only be partially validated. "
"This is likely due to unit-test flakiness, but you may " "This is likely due to unit-test flakiness, but you may "
"want to do additional manual checks if you made " "want to do additional manual checks if you made "
"significant changes to the codebase.".format(test_segmentation_model.__name__) "significant changes to the codebase."
) )
warnings.warn(msg, RuntimeWarning) warnings.warn(msg, RuntimeWarning)
pytest.skip(msg) pytest.skip(msg)
...@@ -638,10 +638,10 @@ def test_detection_model(model_fn, dev): ...@@ -638,10 +638,10 @@ def test_detection_model(model_fn, dev):
if not full_validation: if not full_validation:
msg = ( msg = (
"The output of {} could only be partially validated. " f"The output of {test_detection_model.__name__} could only be partially validated. "
"This is likely due to unit-test flakiness, but you may " "This is likely due to unit-test flakiness, but you may "
"want to do additional manual checks if you made " "want to do additional manual checks if you made "
"significant changes to the codebase.".format(test_detection_model.__name__) "significant changes to the codebase."
) )
warnings.warn(msg, RuntimeWarning) warnings.warn(msg, RuntimeWarning)
pytest.skip(msg) pytest.skip(msg)
......
...@@ -78,7 +78,7 @@ class TestONNXExporter: ...@@ -78,7 +78,7 @@ class TestONNXExporter:
ort_session = onnxruntime.InferenceSession(onnx_io.getvalue()) ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
# compute onnxruntime output prediction # compute onnxruntime output prediction
ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs)) ort_inputs = {ort_session.get_inputs()[i].name: inpt for i, inpt in enumerate(inputs)}
ort_outs = ort_session.run(None, ort_inputs) ort_outs = ort_session.run(None, ort_inputs)
for i in range(0, len(outputs)): for i in range(0, len(outputs)):
...@@ -185,7 +185,7 @@ class TestONNXExporter: ...@@ -185,7 +185,7 @@ class TestONNXExporter:
def test_resize_images(self): def test_resize_images(self):
class TransformModule(torch.nn.Module): class TransformModule(torch.nn.Module):
def __init__(self_module): def __init__(self_module):
super(TransformModule, self_module).__init__() super().__init__()
self_module.transform = self._init_test_generalized_rcnn_transform() self_module.transform = self._init_test_generalized_rcnn_transform()
def forward(self_module, images): def forward(self_module, images):
...@@ -200,7 +200,7 @@ class TestONNXExporter: ...@@ -200,7 +200,7 @@ class TestONNXExporter:
def test_transform_images(self): def test_transform_images(self):
class TransformModule(torch.nn.Module): class TransformModule(torch.nn.Module):
def __init__(self_module): def __init__(self_module):
super(TransformModule, self_module).__init__() super().__init__()
self_module.transform = self._init_test_generalized_rcnn_transform() self_module.transform = self._init_test_generalized_rcnn_transform()
def forward(self_module, images): def forward(self_module, images):
...@@ -301,7 +301,7 @@ class TestONNXExporter: ...@@ -301,7 +301,7 @@ class TestONNXExporter:
class RPNModule(torch.nn.Module): class RPNModule(torch.nn.Module):
def __init__(self_module): def __init__(self_module):
super(RPNModule, self_module).__init__() super().__init__()
self_module.rpn = self._init_test_rpn() self_module.rpn = self._init_test_rpn()
def forward(self_module, images, features): def forward(self_module, images, features):
...@@ -335,7 +335,7 @@ class TestONNXExporter: ...@@ -335,7 +335,7 @@ class TestONNXExporter:
def test_multi_scale_roi_align(self): def test_multi_scale_roi_align(self):
class TransformModule(torch.nn.Module): class TransformModule(torch.nn.Module):
def __init__(self): def __init__(self):
super(TransformModule, self).__init__() super().__init__()
self.model = ops.MultiScaleRoIAlign(["feat1", "feat2"], 3, 2) self.model = ops.MultiScaleRoIAlign(["feat1", "feat2"], 3, 2)
self.image_sizes = [(512, 512)] self.image_sizes = [(512, 512)]
...@@ -371,7 +371,7 @@ class TestONNXExporter: ...@@ -371,7 +371,7 @@ class TestONNXExporter:
def test_roi_heads(self): def test_roi_heads(self):
class RoiHeadsModule(torch.nn.Module): class RoiHeadsModule(torch.nn.Module):
def __init__(self_module): def __init__(self_module):
super(RoiHeadsModule, self_module).__init__() super().__init__()
self_module.transform = self._init_test_generalized_rcnn_transform() self_module.transform = self._init_test_generalized_rcnn_transform()
self_module.rpn = self._init_test_rpn() self_module.rpn = self._init_test_rpn()
self_module.roi_heads = self._init_test_roi_heads_faster_rcnn() self_module.roi_heads = self._init_test_roi_heads_faster_rcnn()
......
...@@ -711,7 +711,7 @@ class TestDeformConv: ...@@ -711,7 +711,7 @@ class TestDeformConv:
expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation) expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)
torch.testing.assert_close( torch.testing.assert_close(
res.to(expected), expected, rtol=tol, atol=tol, msg="\nres:\n{}\nexpected:\n{}".format(res, expected) res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
) )
# no modulation test # no modulation test
...@@ -719,7 +719,7 @@ class TestDeformConv: ...@@ -719,7 +719,7 @@ class TestDeformConv:
expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation) expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)
torch.testing.assert_close( torch.testing.assert_close(
res.to(expected), expected, rtol=tol, atol=tol, msg="\nres:\n{}\nexpected:\n{}".format(res, expected) res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
) )
def test_wrong_sizes(self): def test_wrong_sizes(self):
......
...@@ -115,7 +115,7 @@ class TestConvertImageDtype: ...@@ -115,7 +115,7 @@ class TestConvertImageDtype:
output_image, output_image,
rtol=0.0, rtol=0.0,
atol=1e-6, atol=1e-6,
msg="{} vs {}".format(output_image_script, output_image), msg=f"{output_image_script} vs {output_image}",
) )
actual_min, actual_max = output_image.tolist() actual_min, actual_max = output_image.tolist()
...@@ -369,7 +369,7 @@ def test_resize(height, width, osize, max_size): ...@@ -369,7 +369,7 @@ def test_resize(height, width, osize, max_size):
t = transforms.Resize(osize, max_size=max_size) t = transforms.Resize(osize, max_size=max_size)
result = t(img) result = t(img)
msg = "{}, {} - {} - {}".format(height, width, osize, max_size) msg = f"{height}, {width} - {osize} - {max_size}"
osize = osize[0] if isinstance(osize, (list, tuple)) else osize osize = osize[0] if isinstance(osize, (list, tuple)) else osize
# If size is an int, smaller edge of the image will be matched to this number. # If size is an int, smaller edge of the image will be matched to this number.
# i.e, if height > width, then image will be rescaled to (size * height / width, size). # i.e, if height > width, then image will be rescaled to (size * height / width, size).
...@@ -470,11 +470,11 @@ class TestPad: ...@@ -470,11 +470,11 @@ class TestPad:
width = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2
img = transforms.ToPILImage()(torch.ones(3, height, width)) img = transforms.ToPILImage()(torch.ones(3, height, width))
padding = tuple([random.randint(1, 20) for _ in range(2)]) padding = tuple(random.randint(1, 20) for _ in range(2))
output = transforms.Pad(padding)(img) output = transforms.Pad(padding)(img)
assert output.size == (width + padding[0] * 2, height + padding[1] * 2) assert output.size == (width + padding[0] * 2, height + padding[1] * 2)
padding = tuple([random.randint(1, 20) for _ in range(4)]) padding = tuple(random.randint(1, 20) for _ in range(4))
output = transforms.Pad(padding)(img) output = transforms.Pad(padding)(img)
assert output.size[0] == width + padding[0] + padding[2] assert output.size[0] == width + padding[0] + padding[2]
assert output.size[1] == height + padding[1] + padding[3] assert output.size[1] == height + padding[1] + padding[3]
...@@ -1703,7 +1703,7 @@ def test_center_crop_2(odd_image_size, delta, delta_width, delta_height): ...@@ -1703,7 +1703,7 @@ def test_center_crop_2(odd_image_size, delta, delta_width, delta_height):
assert_equal( assert_equal(
output_tensor, output_tensor,
output_pil, output_pil,
msg="image_size: {} crop_size: {}".format(input_image_size, crop_size), msg=f"image_size: {input_image_size} crop_size: {crop_size}",
) )
# Check if content in center of both image and cropped output is same. # Check if content in center of both image and cropped output is same.
...@@ -2029,9 +2029,9 @@ class TestAffine: ...@@ -2029,9 +2029,9 @@ class TestAffine:
np_result = np.array(result) np_result = np.array(result)
n_diff_pixels = np.sum(np_result != true_result) / 3 n_diff_pixels = np.sum(np_result != true_result) / 3
# Accept 3 wrong pixels # Accept 3 wrong pixels
error_msg = "angle={}, translate={}, scale={}, shear={}\n".format( error_msg = (
angle, translate, scale, shear f"angle={angle}, translate={translate}, scale={scale}, shear={shear}\nn diff pixels={n_diff_pixels}\n"
) + "n diff pixels={}\n".format(n_diff_pixels) )
assert n_diff_pixels < 3, error_msg assert n_diff_pixels < 3, error_msg
def test_transformation_discrete(self, pil_image, input_img): def test_transformation_discrete(self, pil_image, input_img):
...@@ -2121,12 +2121,8 @@ def test_random_affine(): ...@@ -2121,12 +2121,8 @@ def test_random_affine():
for _ in range(100): for _ in range(100):
angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear, img_size=img.size) angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear, img_size=img.size)
assert -10 < angle < 10 assert -10 < angle < 10
assert -img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5, "{} vs {}".format( assert -img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5
translations[0], img.size[0] * 0.5 assert -img.size[1] * 0.5 <= translations[1] <= img.size[1] * 0.5
)
assert -img.size[1] * 0.5 <= translations[1] <= img.size[1] * 0.5, "{} vs {}".format(
translations[1], img.size[1] * 0.5
)
assert 0.7 < scale < 1.3 assert 0.7 < scale < 1.3
assert -10 < shear[0] < 10 assert -10 < shear[0] < 10
assert -20 < shear[1] < 40 assert -20 < shear[1] < 40
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment