Unverified Commit d367a01a authored by Jirka Borovec's avatar Jirka Borovec Committed by GitHub
Browse files

Use f-strings almost everywhere, and other cleanups by applying pyupgrade (#4585)


Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 50dfe207
......@@ -34,7 +34,6 @@ A diff output is produced and a sensible exit code is returned.
import argparse
import difflib
import fnmatch
import io
import multiprocessing
import os
import signal
......@@ -87,20 +86,20 @@ def list_files(files, recursive=False, extensions=None, exclude=None):
def make_diff(file, original, reformatted):
return list(
difflib.unified_diff(
original, reformatted, fromfile="{}\t(original)".format(file), tofile="{}\t(reformatted)".format(file), n=3
original, reformatted, fromfile=f"{file}\t(original)", tofile=f"{file}\t(reformatted)", n=3
)
)
class DiffError(Exception):
def __init__(self, message, errs=None):
super(DiffError, self).__init__(message)
super().__init__(message)
self.errs = errs or []
class UnexpectedError(Exception):
def __init__(self, message, exc=None):
super(UnexpectedError, self).__init__(message)
super().__init__(message)
self.formatted_traceback = traceback.format_exc()
self.exc = exc
......@@ -112,14 +111,14 @@ def run_clang_format_diff_wrapper(args, file):
except DiffError:
raise
except Exception as e:
raise UnexpectedError("{}: {}: {}".format(file, e.__class__.__name__, e), e)
raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e)
def run_clang_format_diff(args, file):
try:
with io.open(file, "r", encoding="utf-8") as f:
with open(file, encoding="utf-8") as f:
original = f.readlines()
except IOError as exc:
except OSError as exc:
raise DiffError(str(exc))
invocation = [args.clang_format_executable, file]
......@@ -145,7 +144,7 @@ def run_clang_format_diff(args, file):
invocation, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, encoding="utf-8"
)
except OSError as exc:
raise DiffError("Command '{}' failed to start: {}".format(subprocess.list2cmdline(invocation), exc))
raise DiffError(f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}")
proc_stdout = proc.stdout
proc_stderr = proc.stderr
......@@ -203,7 +202,7 @@ def print_trouble(prog, message, use_colors):
error_text = "error:"
if use_colors:
error_text = bold_red(error_text)
print("{}: {} {}".format(prog, error_text, message), file=sys.stderr)
print(f"{prog}: {error_text} {message}", file=sys.stderr)
def main():
......@@ -216,7 +215,7 @@ def main():
)
parser.add_argument(
"--extensions",
help="comma separated list of file extensions (default: {})".format(DEFAULT_EXTENSIONS),
help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})",
default=DEFAULT_EXTENSIONS,
)
parser.add_argument("-r", "--recursive", action="store_true", help="run recursively over directories")
......@@ -227,7 +226,7 @@ def main():
metavar="N",
type=int,
default=0,
help="run N clang-format jobs in parallel" " (default number of cpus + 1)",
help="run N clang-format jobs in parallel (default number of cpus + 1)",
)
parser.add_argument(
"--color", default="auto", choices=["auto", "always", "never"], help="show colored diff (default: auto)"
......@@ -238,7 +237,7 @@ def main():
metavar="PATTERN",
action="append",
default=[],
help="exclude paths matching the given glob-like pattern(s)" " from recursive search",
help="exclude paths matching the given glob-like pattern(s) from recursive search",
)
args = parser.parse_args()
......@@ -263,7 +262,7 @@ def main():
colored_stdout = sys.stdout.isatty()
colored_stderr = sys.stderr.isatty()
version_invocation = [args.clang_format_executable, str("--version")]
version_invocation = [args.clang_format_executable, "--version"]
try:
subprocess.check_call(version_invocation, stdout=DEVNULL)
except subprocess.CalledProcessError as e:
......@@ -272,7 +271,7 @@ def main():
except OSError as e:
print_trouble(
parser.prog,
"Command '{}' failed to start: {}".format(subprocess.list2cmdline(version_invocation), e),
f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}",
use_colors=colored_stderr,
)
return ExitStatus.TROUBLE
......
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: check-docstring-first
- id: check-toml
- id: check-yaml
exclude: packaging/.*
- id: end-of-file-fixer
# - repo: https://github.com/asottile/pyupgrade
# rev: v2.29.0
# hooks:
# - id: pyupgrade
# args: [--py36-plus]
# name: Upgrade code
- repo: https://github.com/omnilib/ufmt
rev: v1.3.0
hooks:
......@@ -6,16 +22,9 @@ repos:
additional_dependencies:
- black == 21.9b0
- usort == 0.6.4
- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2
hooks:
- id: flake8
args: [--config=setup.cfg]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: check-docstring-first
- id: check-toml
- id: check-yaml
exclude: packaging/.*
- id: end-of-file-fixer
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# PyTorch documentation build configuration file, created by
# sphinx-quickstart on Fri Dec 23 13:31:47 2016.
......
......@@ -125,7 +125,7 @@ res_scripted = scripted_predictor(batch)
import json
with open(Path('assets') / 'imagenet_class_index.json', 'r') as labels_file:
with open(Path('assets') / 'imagenet_class_index.json') as labels_file:
labels = json.load(labels_file)
for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)):
......
......@@ -137,7 +137,7 @@ def example_read_video(video_object, start=0, end=None, read_video=True, read_au
if end < start:
raise ValueError(
"end time should be larger than start time, got "
"start time={} and end time={}".format(start, end)
f"start time={start} and end time={end}"
)
video_frames = torch.empty(0)
......
# -*- coding: utf-8 -*-
"""Helper script to package wheels and relocate binaries."""
import glob
......@@ -157,7 +155,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
rename and copy them into the wheel while updating their respective rpaths.
"""
print("Relocating {0}".format(binary))
print(f"Relocating {binary}")
binary_path = osp.join(output_library, binary)
ld_tree = lddtree(binary_path)
......@@ -173,12 +171,12 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
print(library)
if library_info["path"] is None:
print("Omitting {0}".format(library))
print(f"Omitting {library}")
continue
if library in ALLOWLIST:
# Omit glibc/gcc/system libraries
print("Omitting {0}".format(library))
print(f"Omitting {library}")
continue
parent_dependencies = binary_dependencies.get(parent, [])
......@@ -201,7 +199,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
if library != binary:
library_path = binary_paths[library]
new_library_path = patch_new_path(library_path, new_libraries_path)
print("{0} -> {1}".format(library, new_library_path))
print(f"{library} -> {new_library_path}")
shutil.copyfile(library_path, new_library_path)
new_names[library] = new_library_path
......@@ -214,7 +212,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
new_library_name = new_names[library]
for dep in library_dependencies:
new_dep = osp.basename(new_names[dep])
print("{0}: {1} -> {2}".format(library, dep, new_dep))
print(f"{library}: {dep} -> {new_dep}")
subprocess.check_output(
[patchelf, "--replace-needed", dep, new_dep, new_library_name], cwd=new_libraries_path
)
......@@ -228,7 +226,7 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary):
library_dependencies = binary_dependencies[binary]
for dep in library_dependencies:
new_dep = osp.basename(new_names[dep])
print("{0}: {1} -> {2}".format(binary, dep, new_dep))
print(f"{binary}: {dep} -> {new_dep}")
subprocess.check_output([patchelf, "--replace-needed", dep, new_dep, binary], cwd=output_library)
print("Update library rpath")
......@@ -244,7 +242,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
Given a shared library, find the transitive closure of its dependencies,
rename and copy them into the wheel.
"""
print("Relocating {0}".format(binary))
print(f"Relocating {binary}")
binary_path = osp.join(output_library, binary)
library_dlls = find_dll_dependencies(dumpbin, binary_path)
......@@ -255,18 +253,18 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
while binary_queue != []:
library, parent = binary_queue.pop(0)
if library in WINDOWS_ALLOWLIST or library.startswith("api-ms-win"):
print("Omitting {0}".format(library))
print(f"Omitting {library}")
continue
library_path = find_program(library)
if library_path is None:
print("{0} not found".format(library))
print(f"{library} not found")
continue
if osp.basename(osp.dirname(library_path)) == "system32":
continue
print("{0}: {1}".format(library, library_path))
print(f"{library}: {library_path}")
parent_dependencies = binary_dependencies.get(parent, [])
parent_dependencies.append(library)
binary_dependencies[parent] = parent_dependencies
......@@ -284,7 +282,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary):
if library != binary:
library_path = binary_paths[library]
new_library_path = osp.join(package_dir, library)
print("{0} -> {1}".format(library, new_library_path))
print(f"{library} -> {new_library_path}")
shutil.copyfile(library_path, new_library_path)
......@@ -300,16 +298,16 @@ def compress_wheel(output_dir, wheel, wheel_dir, wheel_name):
full_file = osp.join(root, this_file)
rel_file = osp.relpath(full_file, output_dir)
if full_file == record_file:
f.write("{0},,\n".format(rel_file))
f.write(f"{rel_file},,\n")
else:
digest, size = rehash(full_file)
f.write("{0},{1},{2}\n".format(rel_file, digest, size))
f.write(f"{rel_file},{digest},{size}\n")
print("Compressing wheel")
base_wheel_name = osp.join(wheel_dir, wheel_name)
shutil.make_archive(base_wheel_name, "zip", output_dir)
os.remove(wheel)
shutil.move("{0}.zip".format(base_wheel_name), wheel)
shutil.move(f"{base_wheel_name}.zip", wheel)
shutil.rmtree(output_dir)
......@@ -317,9 +315,7 @@ def patch_linux():
# Get patchelf location
patchelf = find_program("patchelf")
if patchelf is None:
raise FileNotFoundError(
"Patchelf was not found in the system, please" " make sure that is available on the PATH."
)
raise FileNotFoundError("Patchelf was not found in the system, please make sure that is available on the PATH.")
# Find wheel
print("Finding wheels...")
......@@ -338,7 +334,7 @@ def patch_linux():
print("Unzipping wheel...")
wheel_file = osp.basename(wheel)
wheel_dir = osp.dirname(wheel)
print("{0}".format(wheel_file))
print(f"{wheel_file}")
wheel_name, _ = osp.splitext(wheel_file)
unzip_file(wheel, output_dir)
......@@ -355,9 +351,7 @@ def patch_win():
# Get dumpbin location
dumpbin = find_program("dumpbin")
if dumpbin is None:
raise FileNotFoundError(
"Dumpbin was not found in the system, please" " make sure that is available on the PATH."
)
raise FileNotFoundError("Dumpbin was not found in the system, please make sure that is available on the PATH.")
# Find wheel
print("Finding wheels...")
......@@ -376,7 +370,7 @@ def patch_win():
print("Unzipping wheel...")
wheel_file = osp.basename(wheel)
wheel_dir = osp.dirname(wheel)
print("{0}".format(wheel_file))
print(f"{wheel_file}")
wheel_name, _ = osp.splitext(wheel_file)
unzip_file(wheel, output_dir)
......
......@@ -26,7 +26,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
header = "Epoch: [{}]".format(epoch)
header = f"Epoch: [{epoch}]"
for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
start_time = time.time()
image, target = image.to(device), target.to(device)
......@@ -121,7 +121,7 @@ def load_data(traindir, valdir, args):
cache_path = _get_cache_path(traindir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_train from {}".format(cache_path))
print(f"Loading dataset_train from {cache_path}")
dataset, _ = torch.load(cache_path)
else:
auto_augment_policy = getattr(args, "auto_augment", None)
......@@ -136,7 +136,7 @@ def load_data(traindir, valdir, args):
),
)
if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path))
print(f"Saving dataset_train to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path)
print("Took", time.time() - st)
......@@ -145,7 +145,7 @@ def load_data(traindir, valdir, args):
cache_path = _get_cache_path(valdir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_test from {}".format(cache_path))
print(f"Loading dataset_test from {cache_path}")
dataset_test, _ = torch.load(cache_path)
else:
if not args.weights:
......@@ -162,7 +162,7 @@ def load_data(traindir, valdir, args):
preprocessing,
)
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
print(f"Saving dataset_test to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset_test, valdir), cache_path)
......@@ -270,8 +270,8 @@ def main(args):
main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
else:
raise RuntimeError(
"Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
"are supported.".format(args.lr_scheduler)
f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
"are supported."
)
if args.lr_warmup_epochs > 0:
......@@ -285,7 +285,7 @@ def main(args):
)
else:
raise RuntimeError(
f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant " "are supported."
f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
......@@ -351,12 +351,12 @@ def main(args):
}
if model_ema:
checkpoint["model_ema"] = model_ema.state_dict()
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Training time {}".format(total_time_str))
print(f"Training time {total_time_str}")
def get_args_parser(add_help=True):
......
......@@ -20,7 +20,7 @@ def main(args):
print(args)
if args.post_training_quantize and args.distributed:
raise RuntimeError("Post training quantization example should not be performed " "on distributed mode")
raise RuntimeError("Post training quantization example should not be performed on distributed mode")
# Set backend engine to ensure that quantized model runs on the correct kernels
if args.backend not in torch.backends.quantized.supported_engines:
......@@ -141,13 +141,13 @@ def main(args):
"epoch": epoch,
"args": args,
}
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
print("Saving models after epoch ", epoch)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Training time {}".format(total_time_str))
print(f"Training time {total_time_str}")
def get_args_parser(add_help=True):
......
......@@ -39,13 +39,13 @@ class RandomMixup(torch.nn.Module):
Tensor: Randomly transformed batch.
"""
if batch.ndim != 4:
raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
elif target.ndim != 1:
raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
elif not batch.is_floating_point():
raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype))
elif target.dtype != torch.int64:
raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
if target.ndim != 1:
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
if not batch.is_floating_point():
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
if target.dtype != torch.int64:
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
if not self.inplace:
batch = batch.clone()
......@@ -115,13 +115,13 @@ class RandomCutmix(torch.nn.Module):
Tensor: Randomly transformed batch.
"""
if batch.ndim != 4:
raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
elif target.ndim != 1:
raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
elif not batch.is_floating_point():
raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype))
elif target.dtype != torch.int64:
raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))
raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
if target.ndim != 1:
raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
if not batch.is_floating_point():
raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
if target.dtype != torch.int64:
raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
if not self.inplace:
batch = batch.clone()
......
......@@ -10,7 +10,7 @@ import torch
import torch.distributed as dist
class SmoothedValue(object):
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
......@@ -65,7 +65,7 @@ class SmoothedValue(object):
)
class MetricLogger(object):
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
......@@ -82,12 +82,12 @@ class MetricLogger(object):
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter)))
loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
......@@ -152,7 +152,7 @@ class MetricLogger(object):
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {}".format(header, total_time_str))
print(f"{header} Total time: {total_time_str}")
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
......@@ -270,7 +270,7 @@ def init_distributed_mode(args):
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
)
......@@ -307,8 +307,7 @@ def average_checkpoints(inputs):
params_keys = model_params_keys
elif params_keys != model_params_keys:
raise KeyError(
"For checkpoint {}, expected list of params: {}, "
"but found: {}".format(f, params_keys, model_params_keys)
f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}"
)
for k in params_keys:
p = model_params[k]
......
......@@ -9,7 +9,7 @@ from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
class FilterAndRemapCocoCategories(object):
class FilterAndRemapCocoCategories:
def __init__(self, categories, remap=True):
self.categories = categories
self.remap = remap
......@@ -44,7 +44,7 @@ def convert_coco_poly_to_mask(segmentations, height, width):
return masks
class ConvertCocoPolysToMask(object):
class ConvertCocoPolysToMask:
def __call__(self, image, target):
w, h = image.size
......@@ -205,11 +205,11 @@ def get_coco_api_from_dataset(dataset):
class CocoDetection(torchvision.datasets.CocoDetection):
def __init__(self, img_folder, ann_file, transforms):
super(CocoDetection, self).__init__(img_folder, ann_file)
super().__init__(img_folder, ann_file)
self._transforms = transforms
def __getitem__(self, idx):
img, target = super(CocoDetection, self).__getitem__(idx)
img, target = super().__getitem__(idx)
image_id = self.ids[idx]
target = dict(image_id=image_id, annotations=target)
if self._transforms is not None:
......
......@@ -13,7 +13,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
header = "Epoch: [{}]".format(epoch)
header = f"Epoch: [{epoch}]"
lr_scheduler = None
if epoch == 0:
......@@ -39,7 +39,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
loss_value = losses_reduced.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
print(f"Loss is {loss_value}, stopping training")
print(loss_dict_reduced)
sys.exit(1)
......
......@@ -36,9 +36,7 @@ class GroupedBatchSampler(BatchSampler):
def __init__(self, sampler, group_ids, batch_size):
if not isinstance(sampler, Sampler):
raise ValueError(
"sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
)
raise ValueError(f"sampler should be an instance of torch.utils.data.Sampler, but got sampler={sampler}")
self.sampler = sampler
self.group_ids = group_ids
self.batch_size = batch_size
......@@ -193,6 +191,6 @@ def create_aspect_ratio_groups(dataset, k=0):
# count number of elements per group
counts = np.unique(groups, return_counts=True)[1]
fbins = [0] + bins + [np.inf]
print("Using {} as bins for aspect ratio quantization".format(fbins))
print("Count of instances per bin: {}".format(counts))
print(f"Using {fbins} as bins for aspect ratio quantization")
print(f"Count of instances per bin: {counts}")
return groups
......@@ -65,7 +65,7 @@ def get_args_parser(add_help=True):
"--lr",
default=0.02,
type=float,
help="initial learning rate, 0.02 is the default value for training " "on 8 gpus and 2 images_per_gpu",
help="initial learning rate, 0.02 is the default value for training on 8 gpus and 2 images_per_gpu",
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
......@@ -197,8 +197,7 @@ def main(args):
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
else:
raise RuntimeError(
"Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR "
"are supported.".format(args.lr_scheduler)
f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
)
if args.resume:
......@@ -227,7 +226,7 @@ def main(args):
"args": args,
"epoch": epoch,
}
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
# evaluate after every epoch
......@@ -235,7 +234,7 @@ def main(args):
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Training time {}".format(total_time_str))
print(f"Training time {total_time_str}")
if __name__ == "__main__":
......
......@@ -17,7 +17,7 @@ def _flip_coco_person_keypoints(kps, width):
return flipped_data
class Compose(object):
class Compose:
def __init__(self, transforms):
self.transforms = transforms
......@@ -103,7 +103,7 @@ class RandomIoUCrop(nn.Module):
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError("image should be 2/3 dimensional. Got {} dimensions.".format(image.ndimension()))
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
elif image.ndimension() == 2:
image = image.unsqueeze(0)
......@@ -171,7 +171,7 @@ class RandomZoomOut(nn.Module):
self.fill = fill
self.side_range = side_range
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError("Invalid canvas side range provided {}.".format(side_range))
raise ValueError(f"Invalid canvas side range provided {side_range}.")
self.p = p
@torch.jit.unused
......@@ -185,7 +185,7 @@ class RandomZoomOut(nn.Module):
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError("image should be 2/3 dimensional. Got {} dimensions.".format(image.ndimension()))
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
elif image.ndimension() == 2:
image = image.unsqueeze(0)
......@@ -244,7 +244,7 @@ class RandomPhotometricDistort(nn.Module):
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError("image should be 2/3 dimensional. Got {} dimensions.".format(image.ndimension()))
raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
elif image.ndimension() == 2:
image = image.unsqueeze(0)
......
......@@ -8,7 +8,7 @@ import torch
import torch.distributed as dist
class SmoothedValue(object):
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
......@@ -110,7 +110,7 @@ def reduce_dict(input_dict, average=True):
return reduced_dict
class MetricLogger(object):
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
......@@ -127,12 +127,12 @@ class MetricLogger(object):
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter)))
loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
......@@ -197,7 +197,7 @@ class MetricLogger(object):
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
print(f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)")
def collate_fn(batch):
......@@ -274,7 +274,7 @@ def init_distributed_mode(args):
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
)
......
......@@ -9,7 +9,7 @@ from pycocotools import mask as coco_mask
from transforms import Compose
class FilterAndRemapCocoCategories(object):
class FilterAndRemapCocoCategories:
def __init__(self, categories, remap=True):
self.categories = categories
self.remap = remap
......@@ -41,7 +41,7 @@ def convert_coco_poly_to_mask(segmentations, height, width):
return masks
class ConvertCocoPolysToMask(object):
class ConvertCocoPolysToMask:
def __call__(self, image, anno):
w, h = image.size
segmentations = [obj["segmentation"] for obj in anno]
......
......@@ -66,7 +66,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
header = "Epoch: [{}]".format(epoch)
header = f"Epoch: [{epoch}]"
for image, target in metric_logger.log_every(data_loader, print_freq, header):
image, target = image.to(device), target.to(device)
output = model(image)
......@@ -152,8 +152,7 @@ def main(args):
)
else:
raise RuntimeError(
"Invalid warmup lr method '{}'. Only linear and constant "
"are supported.".format(args.lr_warmup_method)
f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
......@@ -188,12 +187,12 @@ def main(args):
"epoch": epoch,
"args": args,
}
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Training time {}".format(total_time_str))
print(f"Training time {total_time_str}")
def get_args_parser(add_help=True):
......
......@@ -16,7 +16,7 @@ def pad_if_smaller(img, size, fill=0):
return img
class Compose(object):
class Compose:
def __init__(self, transforms):
self.transforms = transforms
......@@ -26,7 +26,7 @@ class Compose(object):
return image, target
class RandomResize(object):
class RandomResize:
def __init__(self, min_size, max_size=None):
self.min_size = min_size
if max_size is None:
......@@ -40,7 +40,7 @@ class RandomResize(object):
return image, target
class RandomHorizontalFlip(object):
class RandomHorizontalFlip:
def __init__(self, flip_prob):
self.flip_prob = flip_prob
......@@ -51,7 +51,7 @@ class RandomHorizontalFlip(object):
return image, target
class RandomCrop(object):
class RandomCrop:
def __init__(self, size):
self.size = size
......@@ -64,7 +64,7 @@ class RandomCrop(object):
return image, target
class CenterCrop(object):
class CenterCrop:
def __init__(self, size):
self.size = size
......@@ -90,7 +90,7 @@ class ConvertImageDtype:
return image, target
class Normalize(object):
class Normalize:
def __init__(self, mean, std):
self.mean = mean
self.std = std
......
......@@ -8,7 +8,7 @@ import torch
import torch.distributed as dist
class SmoothedValue(object):
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
......@@ -67,7 +67,7 @@ class SmoothedValue(object):
)
class ConfusionMatrix(object):
class ConfusionMatrix:
def __init__(self, num_classes):
self.num_classes = num_classes
self.mat = None
......@@ -101,15 +101,15 @@ class ConfusionMatrix(object):
def __str__(self):
acc_global, acc, iu = self.compute()
return ("global correct: {:.1f}\n" "average row correct: {}\n" "IoU: {}\n" "mean IoU: {:.1f}").format(
return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
acc_global.item() * 100,
["{:.1f}".format(i) for i in (acc * 100).tolist()],
["{:.1f}".format(i) for i in (iu * 100).tolist()],
[f"{i:.1f}" for i in (acc * 100).tolist()],
[f"{i:.1f}" for i in (iu * 100).tolist()],
iu.mean().item() * 100,
)
class MetricLogger(object):
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
......@@ -126,12 +126,12 @@ class MetricLogger(object):
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter)))
loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
......@@ -196,7 +196,7 @@ class MetricLogger(object):
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {}".format(header, total_time_str))
print(f"{header} Total time: {total_time_str}")
def cat_list(images, fill_value=0):
......@@ -287,7 +287,7 @@ def init_distributed_mode(args):
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment